Compare commits

..

2 Commits

Author SHA1 Message Date
Parth Sareen
6b2abfb433 server: add tests and fix isHuggingFaceURL edge case
- Add comprehensive tests for isHuggingFaceURL and getNumDownloadParts
- Fix bug where domains ending in huggingface.co (like nothuggingface.co)
  would incorrectly match as HuggingFace URLs
- Improve code comments with more detailed documentation
2026-01-18 16:45:17 -08:00
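The isHuggingFaceURL change itself is not reproduced in this compare, but a host-suffix check of the kind the commit message describes might look like the sketch below. The function name comes from the message; the package layout and use of net/url are assumptions for illustration only.

```go
// Hypothetical sketch: treat huggingface.co and its subdomains as HuggingFace
// URLs, but not unrelated hosts that merely end in the string "huggingface.co".
package main

import (
	"fmt"
	"net/url"
	"strings"
)

func isHuggingFaceURL(rawURL string) bool {
	u, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	host := strings.ToLower(u.Hostname())
	// Exact match or a real subdomain boundary; "nothuggingface.co" fails both.
	return host == "huggingface.co" || strings.HasSuffix(host, ".huggingface.co")
}

func main() {
	fmt.Println(isHuggingFaceURL("https://huggingface.co/org/model"))    // true
	fmt.Println(isHuggingFaceURL("https://cdn.huggingface.co/file"))     // true
	fmt.Println(isHuggingFaceURL("https://nothuggingface.co/org/model")) // false
}
```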
Parth Sareen
805ed4644c server: reduce download concurrency for HuggingFace URLs
Reduces concurrent download parts from 16 to 4 for HuggingFace URLs
to avoid triggering rate limits (HTTP 429 errors).

Adds OLLAMA_HF_CONCURRENCY environment variable for users who want
to customize the concurrency level.

Fixes #13297
2026-01-18 16:38:49 -08:00
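The server-side diff for this commit is likewise not shown here. A minimal sketch of the behavior the message describes, with getNumDownloadParts and the constant values assumed rather than taken from the actual download code, could look like:

```go
// Hypothetical sketch: 16 parallel download parts by default, 4 for HuggingFace
// URLs to avoid HTTP 429 rate limits, with OLLAMA_HF_CONCURRENCY as an override.
package main

import (
	"fmt"
	"os"
	"strconv"
)

const (
	defaultNumDownloadParts     = 16
	huggingFaceNumDownloadParts = 4
)

func getNumDownloadParts(isHuggingFace bool) int {
	if !isHuggingFace {
		return defaultNumDownloadParts
	}
	// Allow users to raise or lower the HuggingFace concurrency explicitly.
	if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return huggingFaceNumDownloadParts
}

func main() {
	fmt.Println(getNumDownloadParts(false)) // 16
	fmt.Println(getNumDownloadParts(true))  // 4, unless OLLAMA_HF_CONCURRENCY is set
}
```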
89 changed files with 9170 additions and 994 deletions

View File

@@ -190,7 +190,7 @@ if(MLX_ENGINE)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
+PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
-Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
+Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,38 +260,6 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
```shell
go build -tags mlx -o ollama-mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -322,7 +290,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
-- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -454,7 +421,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -526,7 +493,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -669,7 +636,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -678,5 +644,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
-- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -14,7 +14,6 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable;
-@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end
@implementation AppDelegate
@@ -41,13 +40,6 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
-// Register for system shutdown/restart notification so we can allow termination
-[[[NSWorkspace sharedWorkspace] notificationCenter]
-addObserver:self
-selector:@selector(systemWillPowerOff:)
-name:NSWorkspaceWillPowerOffNotification
-object:nil];
// if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) {
@@ -286,18 +278,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES];
}
-- (void)systemWillPowerOff:(NSNotification *)notification {
-// Set flag so applicationShouldTerminate: knows to allow termination.
-// The system will call applicationShouldTerminate: after posting this notification.
-self.systemShutdownInProgress = YES;
-}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
-// Allow termination if the system is shutting down or restarting
-if (self.systemShutdownInProgress) {
-return NSTerminateNow;
-}
-// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel;

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ", Prompt: ">>> ",
AltPrompt: "... ", AltPrompt: "... ",
Placeholder: "Send a message (/? for help)", Placeholder: "Send a message (/? for help)",
AltPlaceholder: "Press Enter to send", AltPlaceholder: `Use """ to end multi-line input`,
}) })
if err != nil { if err != nil {
return err return err

View File

@@ -21,7 +21,6 @@ ollama pull glm-4.7:cloud
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
-export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
@@ -248,13 +247,12 @@ curl -X POST http://localhost:11434/v1/messages \
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
-export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
-const results = await client.webSearch("what is ollama?");
+const results = await client.webSearch({ query: "what is ollama?" });
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
-const fetchResult = await client.webFetch("https://ollama.com");
+const fetchResult = await client.webFetch({ url: "https://ollama.com" });
console.log(JSON.stringify(fetchResult, null, 2));
```

View File

@@ -111,9 +111,7 @@
"/integrations/zed", "/integrations/zed",
"/integrations/roo-code", "/integrations/roo-code",
"/integrations/n8n", "/integrations/n8n",
"/integrations/xcode", "/integrations/xcode"
"/integrations/onyx",
"/integrations/marimo"
] ]
}, },
{ {

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
-By default, Ollama uses a context window size of 4096 tokens.
+By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

Binary files not shown. (Nine screenshot images, 80–306 KiB each, listed only on the "Before" side.)

View File

@@ -25,7 +25,6 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
-export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
@@ -39,7 +38,7 @@ claude --model qwen3-coder
Or run with environment variables inline:
```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com

View File

@@ -1,73 +0,0 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, go to the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base url to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -1,63 +0,0 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Login to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the set-up process select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -1,5 +1,5 @@
---
-title: Linux
+title: "Linux"
---
## Install
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
-If you are upgrading from a prior version, you should remove the old libraries
-with `sudo rm -rf /usr/lib/ollama` first.
+If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+| sudo tar zx -C /usr
```
Start Ollama:
@@ -41,8 +40,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
+| sudo tar zx -C /usr
```
### ARM64 install
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
Download and extract the ARM64-specific package:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
+| sudo tar zx -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -113,11 +112,7 @@ sudo systemctl status ollama
```
<Note>
-While AMD has contributed the `amdgpu` driver upstream to the official linux
-kernel source, the version is older and may not support all ROCm features. We
-recommend you install the latest driver from
-https://www.amd.com/en/support/linux-drivers for best support of your Radeon
-GPU.
+While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
</Note>
## Customizing
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+| sudo tar zx -C /usr
```
## Installing specific versions
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather") t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
} }
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok { if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String()) t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
} }
case <-ctx.Done(): case <-ctx.Done():

View File

@@ -1464,11 +1464,6 @@ type CompletionRequest struct {
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
-// Image generation fields
-Width int32 `json:"width,omitempty"`
-Height int32 `json:"height,omitempty"`
-Seed int64 `json:"seed,omitempty"`
}
// DoneReason represents the reason why a completion response is done
@@ -1517,11 +1512,6 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
-// Image generation fields
-Image []byte `json:"image,omitempty"` // Generated image
-Step int `json:"step,omitempty"` // Current generation step
-Total int `json:"total,omitempty"` // Total generation steps
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

View File

@@ -8,7 +8,6 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -442,7 +441,6 @@ type ResponsesWriter struct {
stream bool
responseID string
itemID string
-request openai.ResponsesRequest
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -480,9 +478,7 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
-response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
-completedAt := time.Now().Unix()
-response.CompletedAt = &completedAt
+response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -527,12 +523,11 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
+converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
-request: req,
}
// Set headers based on streaming mode

View File

@@ -630,10 +630,6 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
-if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
-return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
-}
types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand" "math/rand"
"time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -266,9 +265,9 @@ type ResponsesText struct {
type ResponsesTool struct {
Type string `json:"type"` // "function"
Name string `json:"name"`
-Description *string `json:"description"` // nullable but required
-Strict *bool `json:"strict"` // nullable but required
-Parameters map[string]any `json:"parameters"` // nullable but required
+Description string `json:"description,omitempty"`
+Strict bool `json:"strict,omitempty"`
+Parameters map[string]any `json:"parameters,omitempty"`
}
type ResponsesRequest struct {
@@ -476,16 +475,11 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
}
}
-var description string
-if t.Description != nil {
-description = *t.Description
-}
return api.Tool{
Type: t.Type,
Function: api.ToolFunction{
Name: t.Name,
-Description: description,
+Description: t.Description,
Parameters: params,
},
}, nil
@@ -522,60 +516,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
// Response types for the Responses API
// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
Format ResponsesTextFormat `json:"format"`
}
// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
Effort *string `json:"effort,omitempty"`
Summary *string `json:"summary,omitempty"`
}
// ResponsesError represents an error in the response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"`
}
type ResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
-CompletedAt *int64 `json:"completed_at"`
-Status string `json:"status"`
-IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
-Model string `json:"model"`
-PreviousResponseID *string `json:"previous_response_id"`
-Instructions *string `json:"instructions"`
-Output []ResponsesOutputItem `json:"output"`
-Error *ResponsesError `json:"error"`
-Tools []ResponsesTool `json:"tools"`
-ToolChoice any `json:"tool_choice"`
-Truncation string `json:"truncation"`
-ParallelToolCalls bool `json:"parallel_tool_calls"`
-Text ResponsesTextField `json:"text"`
-TopP float64 `json:"top_p"`
-PresencePenalty float64 `json:"presence_penalty"`
-FrequencyPenalty float64 `json:"frequency_penalty"`
-TopLogprobs int `json:"top_logprobs"`
-Temperature float64 `json:"temperature"`
-Reasoning *ResponsesReasoningOutput `json:"reasoning"`
-Usage *ResponsesUsage `json:"usage"`
-MaxOutputTokens *int `json:"max_output_tokens"`
-MaxToolCalls *int `json:"max_tool_calls"`
-Store bool `json:"store"`
-Background bool `json:"background"`
-ServiceTier string `json:"service_tier"`
-Metadata map[string]any `json:"metadata"`
-SafetyIdentifier *string `json:"safety_identifier"`
-PromptCacheKey *string `json:"prompt_cache_key"`
+Status string `json:"status"`
+Model string `json:"model"`
+Output []ResponsesOutputItem `json:"output"`
+Usage *ResponsesUsage `json:"usage,omitempty"`
+// TODO(drifkin): add `temperature` and `top_p` to the response, but this
+// requires additional plumbing to find the effective values since the
+// defaults can come from the model or the request
}
type ResponsesOutputItem struct {
@@ -599,39 +550,18 @@ type ResponsesReasoningSummary struct {
}
type ResponsesOutputContent struct {
Type string `json:"type"` // "output_text"
Text string `json:"text"`
-Annotations []any `json:"annotations"`
-Logprobs []any `json:"logprobs"`
-}
-type ResponsesInputTokensDetails struct {
-CachedTokens int `json:"cached_tokens"`
-}
-type ResponsesOutputTokensDetails struct {
-ReasoningTokens int `json:"reasoning_tokens"`
}
type ResponsesUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
-InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
-OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
}
-// derefFloat64 returns the value of a float64 pointer, or a default if nil.
-func derefFloat64(p *float64, def float64) float64 {
-if p != nil {
-return *p
-}
-return def
-}
-// ToResponse converts an api.ChatResponse to a Responses API response.
-// The request is used to echo back request parameters in the response.
-func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
+// ToResponse converts an api.ChatResponse to a Responses API response
+func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
var output []ResponsesOutputItem
// Add reasoning item if thinking is present
@@ -655,7 +585,6 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
output = append(output, ResponsesOutputItem{
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
Type: "function_call",
-Status: "completed",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
@@ -669,90 +598,25 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
Role: "assistant", Role: "assistant",
Content: []ResponsesOutputContent{ Content: []ResponsesOutputContent{
{ {
Type: "output_text", Type: "output_text",
Text: chatResponse.Message.Content, Text: chatResponse.Message.Content,
Annotations: []any{},
Logprobs: []any{},
}, },
}, },
}) })
} }
var instructions *string
if request.Instructions != "" {
instructions = &request.Instructions
}
// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
truncation = *request.Truncation
}
tools := request.Tools
if tools == nil {
tools = []ResponsesTool{}
}
text := ResponsesTextField{
Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
text.Format = *request.Text.Format
}
// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
reasoning = &ResponsesReasoningOutput{}
if request.Reasoning.Effort != "" {
reasoning.Effort = &request.Reasoning.Effort
}
if request.Reasoning.Summary != "" {
reasoning.Summary = &request.Reasoning.Summary
}
}
return ResponsesResponse{
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
-CompletedAt: nil, // Set by middleware when writing final response
-Status: "completed",
-IncompleteDetails: nil, // Only populated if response incomplete
-Model: model,
-PreviousResponseID: nil, // Not supported
-Instructions: instructions,
-Output: output,
-Error: nil, // Only populated on failure
-Tools: tools,
-ToolChoice: "auto", // Default value
-Truncation: truncation,
-ParallelToolCalls: true, // Default value
-Text: text,
-TopP: derefFloat64(request.TopP, 1.0),
-PresencePenalty: 0, // Default value
-FrequencyPenalty: 0, // Default value
-TopLogprobs: 0, // Default value
-Temperature: derefFloat64(request.Temperature, 1.0),
-Reasoning: reasoning,
+Status: "completed",
+Model: model,
+Output: output,
Usage: &ResponsesUsage{
InputTokens: chatResponse.PromptEvalCount,
OutputTokens: chatResponse.EvalCount,
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
-// TODO(drifkin): wire through the actual values
-InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
-// TODO(drifkin): wire through the actual values
-OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
},
-MaxOutputTokens: request.MaxOutputTokens,
-MaxToolCalls: nil, // Not supported
-Store: false, // We don't store responses
-Background: request.Background,
-ServiceTier: "default", // Default value
-Metadata: map[string]any{},
-SafetyIdentifier: nil, // Not supported
-PromptCacheKey: nil, // Not supported
}
}
@@ -772,7 +636,6 @@ type ResponsesStreamConverter struct {
responseID string
itemID string
model string
-request ResponsesRequest
// State tracking (mutated across Process calls)
firstWrite bool
@@ -805,12 +668,11 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
}
// NewResponsesStreamConverter creates a new converter with the given configuration.
-func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
+func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
return &ResponsesStreamConverter{
responseID: responseID,
itemID: itemID,
model: model,
-request: request,
firstWrite: true,
}
}
@@ -855,120 +717,25 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
return events
}
// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
var instructions any = nil
if c.request.Instructions != "" {
instructions = c.request.Instructions
}
truncation := "disabled"
if c.request.Truncation != nil {
truncation = *c.request.Truncation
}
var tools []any
if c.request.Tools != nil {
for _, t := range c.request.Tools {
tools = append(tools, map[string]any{
"type": t.Type,
"name": t.Name,
"description": t.Description,
"strict": t.Strict,
"parameters": t.Parameters,
})
}
}
if tools == nil {
tools = []any{}
}
textFormat := map[string]any{"type": "text"}
if c.request.Text != nil && c.request.Text.Format != nil {
textFormat = map[string]any{
"type": c.request.Text.Format.Type,
}
if c.request.Text.Format.Name != "" {
textFormat["name"] = c.request.Text.Format.Name
}
if c.request.Text.Format.Schema != nil {
textFormat["schema"] = c.request.Text.Format.Schema
}
if c.request.Text.Format.Strict != nil {
textFormat["strict"] = *c.request.Text.Format.Strict
}
}
var reasoning any = nil
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
r := map[string]any{}
if c.request.Reasoning.Effort != "" {
r["effort"] = c.request.Reasoning.Effort
} else {
r["effort"] = nil
}
if c.request.Reasoning.Summary != "" {
r["summary"] = c.request.Reasoning.Summary
} else {
r["summary"] = nil
}
reasoning = r
}
// Build top_p and temperature with defaults
topP := 1.0
if c.request.TopP != nil {
topP = *c.request.TopP
}
temperature := 1.0
if c.request.Temperature != nil {
temperature = *c.request.Temperature
}
return map[string]any{
"id": c.responseID,
"object": "response",
"created_at": time.Now().Unix(),
"completed_at": nil,
"status": status,
"incomplete_details": nil,
"model": c.model,
"previous_response_id": nil,
"instructions": instructions,
"output": output,
"error": nil,
"tools": tools,
"tool_choice": "auto",
"truncation": truncation,
"parallel_tool_calls": true,
"text": map[string]any{"format": textFormat},
"top_p": topP,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_logprobs": 0,
"temperature": temperature,
"reasoning": reasoning,
"usage": usage,
"max_output_tokens": c.request.MaxOutputTokens,
"max_tool_calls": nil,
"store": false,
"background": c.request.Background,
"service_tier": "default",
"metadata": map[string]any{},
"safety_identifier": nil,
"prompt_cache_key": nil,
}
}
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
return c.newEvent("response.created", map[string]any{
-"response": c.buildResponseObject("in_progress", []any{}, nil),
+"response": map[string]any{
+"id": c.responseID,
+"object": "response",
+"status": "in_progress",
+"output": []any{},
+},
})
}
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
return c.newEvent("response.in_progress", map[string]any{
-"response": c.buildResponseObject("in_progress", []any{}, nil),
+"response": map[string]any{
+"id": c.responseID,
+"object": "response",
+"status": "in_progress",
+"output": []any{},
+},
})
}
@@ -995,10 +762,9 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
// Emit delta
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
-"summary_index": 0,
"delta": thinking,
}))
// TODO(drifkin): consider adding
@@ -1017,10 +783,9 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
events := []ResponsesStreamEvent{
c.newEvent("response.reasoning_summary_text.done", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
-"summary_index": 0,
"text": c.accumulatedThinking,
}),
c.newEvent("response.output_item.done", map[string]any{
"output_index": c.outputIndex,
@@ -1133,10 +898,8 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": c.contentIndex, "content_index": c.contentIndex,
"part": map[string]any{ "part": map[string]any{
"type": "output_text", "type": "output_text",
"text": "", "text": "",
"annotations": []any{},
"logprobs": []any{},
}, },
})) }))
} }
@@ -1150,7 +913,6 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"delta": content, "delta": content,
"logprobs": []any{},
})) }))
return events return events
@@ -1182,10 +944,8 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
"status": "completed", "status": "completed",
"role": "assistant", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}}, }},
}) })
} }
@@ -1207,7 +967,6 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"text": c.accumulatedText, "text": c.accumulatedText,
"logprobs": []any{},
})) }))
// response.content_part.done // response.content_part.done
@@ -1216,10 +975,8 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"part": map[string]any{ "part": map[string]any{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}, },
})) }))
@@ -1232,31 +989,26 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"status": "completed", "status": "completed",
"role": "assistant", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}}, }},
}, },
})) }))
} }
// response.completed // response.completed
usage := map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
"output_tokens_details": map[string]any{
"reasoning_tokens": 0,
},
}
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
response["completed_at"] = time.Now().Unix()
events = append(events, c.newEvent("response.completed", map[string]any{ events = append(events, c.newEvent("response.completed", map[string]any{
"response": response, "response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "completed",
"output": c.buildFinalOutput(),
"usage": map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
},
},
})) }))
return events return events

View File

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk with content
events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk with thinking
events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42", Content: "The answer is 42",
}, },
Done: true, Done: true,
}, ResponsesRequest{}) })
// Should have 2 output items: reasoning + message // Should have 2 output items: reasoning + message
if len(response.Output) != 2 { if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk
converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// Process some content
converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field
-converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
) )
type Prompt struct { type Prompt struct {
@@ -37,11 +36,10 @@ type Terminal struct {
}
type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
-pastedLines []string
}
func New(prompt Prompt) (*Instance, error) {
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
-i.pastedLines = nil
-i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
-if buf.IsEmpty() && len(i.pastedLines) > 0 {
-lastIdx := len(i.pastedLines) - 1
-prevLine := i.pastedLines[lastIdx]
-i.pastedLines = i.pastedLines[:lastIdx]
-fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
-if len(i.pastedLines) == 0 {
-fmt.Print(i.Prompt.Prompt)
-i.Prompt.UseAlt = false
-} else {
-fmt.Print(i.Prompt.AltPrompt)
-}
-for _, r := range prevLine {
-buf.Add(r)
-}
-} else {
-buf.Remove()
-}
+buf.Remove()
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
-case CharCtrlJ:
-i.pastedLines = append(i.pastedLines, buf.String())
-buf.Buf.Clear()
-buf.Pos = 0
-buf.DisplayPos = 0
-buf.LineHasSpace.Clear()
-fmt.Println()
-fmt.Print(i.Prompt.AltPrompt)
-i.Prompt.UseAlt = true
-continue
-case CharEnter:
+case CharEnter, CharCtrlJ:
output := buf.String()
-if len(i.pastedLines) > 0 {
-output = strings.Join(i.pastedLines, "\n") + "\n" + output
-i.pastedLines = nil
-}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
-i.Prompt.UseAlt = false
return output, nil
default:

View File

@@ -179,7 +179,7 @@ _build_macapp() {
fi
rm -f dist/Ollama-darwin.zip
-ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
@@ -187,7 +187,7 @@ _build_macapp() {
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
-ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
rm -f dist/Ollama.dmg

View File

@@ -50,17 +50,12 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil
}
-func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
+func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}
-// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
-if redirectURL.Host != originalHost {
-return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
-}
sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))

View File

@@ -1,113 +0,0 @@
package server
import (
"context"
"strings"
"testing"
"time"
)
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}
for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}
func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}
for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}
func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}
u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}
if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}
q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}
// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}
func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}

View File

@@ -95,11 +95,48 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
} }
const ( const (
numDownloadParts = 16 // numDownloadParts is the default number of concurrent download parts for standard downloads
numDownloadParts = 16
// numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
// downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
numHFDownloadParts = 4
minDownloadPartSize int64 = 100 * format.MegaByte minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte maxDownloadPartSize int64 = 1000 * format.MegaByte
) )
// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
// This includes:
// - huggingface.co (main domain)
// - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
// - hf.co (shortlink domain)
// - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
func isHuggingFaceURL(u *url.URL) bool {
if u == nil {
return false
}
host := strings.ToLower(u.Hostname())
return host == "huggingface.co" ||
strings.HasSuffix(host, ".huggingface.co") ||
host == "hf.co" ||
strings.HasSuffix(host, ".hf.co")
}
// getNumDownloadParts returns the number of concurrent download parts to use
// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
func getNumDownloadParts(u *url.URL) int {
if isHuggingFaceURL(u) {
if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return numHFDownloadParts
}
return numDownloadParts
}
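
As a quick sketch of how the new override is meant to be exercised from the shell (hedged: the model reference below is a placeholder, not a published model, and `OLLAMA_HF_CONCURRENCY` is read by the server process, so it must be set where `ollama serve` runs):

```bash
# Cap HuggingFace download parts at 2 instead of the reduced default of 4.
# hf.co/example/model is a placeholder model reference.
OLLAMA_HF_CONCURRENCY=2 ollama serve &
ollama pull hf.co/example/model   # parts for HuggingFace-hosted blobs are capped at 2
# Non-HuggingFace registries keep the standard 16 parts regardless of this variable.
```
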
func (p *blobDownloadPart) Name() string { func (p *blobDownloadPart) Name() string {
return strings.Join([]string{ return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N), p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -271,7 +308,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
} }
g, inner := errgroup.WithContext(ctx) g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts) concurrency := getNumDownloadParts(directURL)
if concurrency != numDownloadParts {
slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
}
g.SetLimit(concurrency)
for i := range b.Parts { for i := range b.Parts {
part := b.Parts[i] part := b.Parts[i]
if part.Completed.Load() == part.Size { if part.Completed.Load() == part.Size {

194
server/download_test.go Normal file
View File

@@ -0,0 +1,194 @@
package server
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsHuggingFaceURL(t *testing.T) {
tests := []struct {
name string
url string
expected bool
}{
{
name: "nil url",
url: "",
expected: false,
},
{
name: "huggingface.co main domain",
url: "https://huggingface.co/some/model",
expected: true,
},
{
name: "cdn-lfs.huggingface.co subdomain",
url: "https://cdn-lfs.huggingface.co/repos/abc/123",
expected: true,
},
{
name: "cdn-lfs3.hf.co CDN domain",
url: "https://cdn-lfs3.hf.co/repos/abc/123",
expected: true,
},
{
name: "hf.co shortlink domain",
url: "https://hf.co/model",
expected: true,
},
{
name: "uppercase HuggingFace domain",
url: "https://HUGGINGFACE.CO/model",
expected: true,
},
{
name: "mixed case HF domain",
url: "https://Cdn-Lfs.HF.Co/repos",
expected: true,
},
{
name: "ollama registry",
url: "https://registry.ollama.ai/v2/library/llama3",
expected: false,
},
{
name: "github.com",
url: "https://github.com/ollama/ollama",
expected: false,
},
{
name: "fake huggingface domain",
url: "https://nothuggingface.co/model",
expected: false,
},
{
name: "fake hf domain",
url: "https://nothf.co/model",
expected: false,
},
{
name: "huggingface in path not host",
url: "https://example.com/huggingface.co/model",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var u *url.URL
if tc.url != "" {
var err error
u, err = url.Parse(tc.url)
if err != nil {
t.Fatalf("failed to parse URL: %v", err)
}
}
got := isHuggingFaceURL(u)
assert.Equal(t, tc.expected, got)
})
}
}
func TestGetNumDownloadParts(t *testing.T) {
tests := []struct {
name string
url string
envValue string
expected int
description string
}{
{
name: "nil url returns default",
url: "",
envValue: "",
expected: numDownloadParts,
description: "nil URL should return standard concurrency",
},
{
name: "ollama registry returns default",
url: "https://registry.ollama.ai/v2/library/llama3",
envValue: "",
expected: numDownloadParts,
description: "Ollama registry should use standard concurrency",
},
{
name: "huggingface returns reduced default",
url: "https://huggingface.co/model/repo",
envValue: "",
expected: numHFDownloadParts,
description: "HuggingFace should use reduced concurrency",
},
{
name: "hf.co CDN returns reduced default",
url: "https://cdn-lfs3.hf.co/repos/abc/123",
envValue: "",
expected: numHFDownloadParts,
description: "HuggingFace CDN should use reduced concurrency",
},
{
name: "huggingface with env override",
url: "https://huggingface.co/model/repo",
envValue: "2",
expected: 2,
description: "OLLAMA_HF_CONCURRENCY should override default",
},
{
name: "huggingface with higher env override",
url: "https://huggingface.co/model/repo",
envValue: "8",
expected: 8,
description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
},
{
name: "huggingface with invalid env (non-numeric)",
url: "https://huggingface.co/model/repo",
envValue: "invalid",
expected: numHFDownloadParts,
description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
},
{
name: "huggingface with invalid env (zero)",
url: "https://huggingface.co/model/repo",
envValue: "0",
expected: numHFDownloadParts,
description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
},
{
name: "huggingface with invalid env (negative)",
url: "https://huggingface.co/model/repo",
envValue: "-1",
expected: numHFDownloadParts,
description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
},
{
name: "non-huggingface ignores env",
url: "https://registry.ollama.ai/v2/library/llama3",
envValue: "2",
expected: numDownloadParts,
description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Set the environment variable when the test case specifies one
if tc.envValue != "" {
t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
}
var u *url.URL
if tc.url != "" {
var err error
u, err = url.Parse(tc.url)
if err != nil {
t.Fatalf("failed to parse URL: %v", err)
}
}
got := getNumDownloadParts(u)
assert.Equal(t, tc.expected, got, tc.description)
})
}
}

View File

@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm, Realm: challenge.Realm,
Service: challenge.Service, Service: challenge.Service,
Scope: challenge.Scope, Scope: challenge.Scope,
}, base.Host) })
} }
if err := transfer.Download(ctx, transfer.DownloadOptions{ if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm, Realm: challenge.Realm,
Service: challenge.Service, Service: challenge.Service,
Scope: challenge.Scope, Scope: challenge.Scope,
}, base.Host) })
} }
return transfer.Upload(ctx, transfer.UploadOptions{ return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
// Handle authentication error with one retry // Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host) token, err := getAuthorizationToken(ctx, challenge)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -51,6 +51,7 @@ import (
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
) )
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -163,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil return runner.llama, model, &opts, nil
} }
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) { func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey() pubKey, err := auth.GetPublicKey()
if err != nil { if err != nil {
@@ -190,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model) name := model.ParseName(req.Model)
if !name.IsValid() { if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with // Ideally this is "invalid model name" but we're keeping with
@@ -1557,12 +1587,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler) r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Experimental OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", s.handleImageGeneration)
// Inference (Anthropic compatibility) // Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler) r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil { if rc != nil {
// wrap old with new // wrap old with new
rs := &registry.Local{ rs := &registry.Local{
@@ -1880,62 +1911,6 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b)) return "call_" + strings.ToLower(string(b))
} }
func (s *Server) handleImageGeneration(c *gin.Context) {
var req struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
m, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Parse size (e.g., "1024x768") into width and height
width, height := int32(1024), int32(1024)
if req.Size != "" {
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
return
}
}
var image []byte
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: width,
Height: height,
}, func(resp llm.CompletionResponse) {
if len(resp.Image) > 0 {
image = resp.Image
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"created": time.Now().Unix(),
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
})
}
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"os" "os"
"slices"
"testing" "testing"
"time" "time"
@@ -16,6 +17,7 @@ import (
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@@ -805,8 +807,32 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false } func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil } func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model // TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted when idle. // loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) { func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond) ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done() defer done()
@@ -838,59 +864,3 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner) require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath) require.Equal(t, "/fake/image/model", runner.modelPath)
} }
// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()
// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()
// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)
// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}

View File

@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
w.Rollback() w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host) token, err := getAuthorizationToken(ctx, challenge)
if err != nil { if err != nil {
return err return err
} }

50
x/README.md Normal file
View File

@@ -0,0 +1,50 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx).
Support is currently limited to macOS and Linux with CUDA GPUs. We're looking to add support for Windows with CUDA soon, as well as other GPU vendors.
### Building ollama-mlx
The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
#### macOS (Apple Silicon and Intel)
```bash
# Build MLX backend libraries
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
go build -tags mlx -o ollama-mlx .
```
#### Linux (CUDA)
On Linux, use the "MLX CUDA 13" or "MLX CUDA 12" preset to enable CUDA with the default Ollama NVIDIA GPU architectures:
```bash
# Build MLX backend libraries with CUDA support
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
go build -tags mlx -o ollama-mlx .
```
#### Using build scripts
The build scripts automatically create the `ollama-mlx` binary:
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
## Image Generation
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
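
A minimal sketch of calling the experimental OpenAI-compatible endpoint once `ollama-mlx serve` is running; the model name below is a placeholder and must be an image generation model available on your machine, and the default port 11434 is assumed:

```bash
# Placeholder model name; substitute an image generation model you have pulled.
curl http://localhost:11434/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "model": "z-image-turbo",
    "prompt": "a watercolor painting of a lighthouse at dusk",
    "size": "1024x1024",
    "response_format": "b64_json"
  }'
# The response carries a "created" timestamp and data[0].b64_json with the
# base64-encoded image.
```
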

View File

@@ -25,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools" "github.com/ollama/ollama/x/tools"
) )
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants // Tool output capping constants
const ( const (
// localModelTokenLimit is the token limit for local models (smaller context). // localModelTokenLimit is the token limit for local models (smaller context).
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ", Prompt: ">>> ",
AltPrompt: "... ", AltPrompt: "... ",
Placeholder: "Send a message (/? for help)", Placeholder: "Send a message (/? for help)",
AltPlaceholder: "Press Enter to send", AltPlaceholder: `Use """ to end multi-line input`,
}) })
if err != nil { if err != nil {
return err return err
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder var sb strings.Builder
var format string var format string
var system string var system string
var multiline MultilineState = MultilineNone
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
} }
scanner.Prompt.UseAlt = false scanner.Prompt.UseAlt = false
sb.Reset() sb.Reset()
multiline = MultilineNone
continue continue
case err != nil: case err != nil:
return err return err
} }
switch { switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil return nil
case strings.HasPrefix(line, "/clear"): case strings.HasPrefix(line, "/clear"):
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]] options[args[2]] = fp[args[2]]
case "system": case "system":
if len(args) < 3 { if len(args) < 3 {
fmt.Println("Usage: /set system <message>") fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue continue
} }
system = strings.Join(args[2:], " ") multiline = MultilineSystem
newMessage := api.Message{Role: "system", Content: system}
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" { if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage messages[len(messages)-1] = newMessage
} else { } else {
messages = append(messages, newMessage) messages = append(messages, newMessage)
} }
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset()
continue continue
default: default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line) sb.WriteString(line)
} }
if sb.Len() > 0 { if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()} newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage) messages = append(messages, newMessage)

231
x/imagegen/api/handler.go Normal file
View File

@@ -0,0 +1,231 @@
package api
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
return resp
}
if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
resp.Data[0].B64JSON = imageBase64
}
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// Error is intentionally discarded; the channel is already being consumed and closes via defer
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}
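
For comparison, a hedged sketch of the native path: when the requested model resolves to a known image generation model, `/api/generate` is delegated to this handler and streams `GenerateResponse`-shaped chunks. The model name is again a placeholder:

```bash
# Placeholder model name; the request shape matches the standard /api/generate API.
curl http://localhost:11434/api/generate \
  -H "Content-Type: application/json" \
  -d '{"model": "z-image-turbo", "prompt": "a red bicycle in the snow"}'
# Streamed chunks report progress in "response"; the final chunk has "done": true.
```
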

31
x/imagegen/api/types.go Normal file
View File

@@ -0,0 +1,31 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}

View File

@@ -7,6 +7,7 @@ package imagegen
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -38,17 +39,77 @@ func DefaultOptions() ImageGenOptions {
return ImageGenOptions{ return ImageGenOptions{
Width: 1024, Width: 1024,
Height: 1024, Height: 1024,
Steps: 0, // 0 means model default Steps: 9,
Seed: 0, // 0 means random Seed: 0, // 0 means random
} }
} }
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command. // RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models. // Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) { func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width") cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height") cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)") cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)") cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt") cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width") cmd.Flags().MarkHidden("width")
@@ -91,18 +152,23 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
} }
// generateImageWithOptions generates an image with the given options. // generateImageWithOptions generates an image with the given options.
// Note: opts are currently unused as the native API doesn't support size parameters. func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{ req := &api.GenerateRequest{
Model: modelName, Model: modelName,
Prompt: prompt, Prompt: prompt,
// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
} }
if keepAlive != nil { if keepAlive != nil {
req.KeepAlive = keepAlive req.KeepAlive = keepAlive

View File

@@ -12,7 +12,6 @@ import (
"path/filepath" "path/filepath"
"runtime/pprof" "runtime/pprof"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3" "github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss" "github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -49,7 +48,7 @@ func main() {
// Image generation params // Image generation params
width := flag.Int("width", 1024, "Image width") width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height") height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)") steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed") seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path") out := flag.String("output", "output.png", "Output path")

View File

@@ -175,63 +175,3 @@ func (m *ModelManifest) HasTensorLayers() bool {
} }
return false return false
} }
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}

View File

@@ -95,3 +95,8 @@ func EstimateVRAM(modelName string) uint64 {
} }
return 21 * GB return 21 * GB
} }
// HasTensorLayers reports whether the named model resolves to a known image generation model (used as a proxy for having tensor layers).
func HasTensorLayers(modelName string) bool {
return ResolveModelName(modelName) != ""
}

View File

@@ -94,6 +94,13 @@ func TestEstimateVRAMDefault(t *testing.T) {
} }
} }
func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}
func TestResolveModelName(t *testing.T) { func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string // Non-existent model should return empty string
result := ResolveModelName("nonexistent-model") result := ResolveModelName("nonexistent-model")

View File

@@ -1,5 +1,7 @@
# MLX Memory Management # MLX Memory Management
> This package will be consolidated with `x/ml/backend/mlx` in the future.
## Automatic Tracking ## Automatic Tracking
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed. All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.

View File

@@ -9,7 +9,6 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache" "github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -173,7 +172,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
cfg.Height = 1024 cfg.Height = 1024
} }
if cfg.Steps <= 0 { if cfg.Steps <= 0 {
cfg.Steps = 50 cfg.Steps = 30
} }
if cfg.CFGScale <= 0 { if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0 cfg.CFGScale = 4.0

View File

@@ -194,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
cfg.Height = 1024 cfg.Height = 1024
} }
if cfg.Steps <= 0 { if cfg.Steps <= 0 {
cfg.Steps = 9 // Z-Image turbo default cfg.Steps = 9 // Turbo default
} }
if cfg.CFGScale <= 0 { if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0 cfg.CFGScale = 4.0

View File

@@ -136,8 +136,16 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Model applies its own defaults for width/height/steps // Apply defaults
// Only seed needs to be set here if not provided if req.Width <= 0 {
req.Width = 1024
}
if req.Height <= 0 {
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
}
if req.Seed <= 0 { if req.Seed <= 0 {
req.Seed = time.Now().UnixNano() req.Seed = time.Now().UnixNano()
} }

View File

@@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -26,11 +25,6 @@ import (
) )
// Server wraps an image generation subprocess to implement llm.LlamaServer. // Server wraps an image generation subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
type Server struct { type Server struct {
mu sync.Mutex mu sync.Mutex
cmd *exec.Cmd cmd *exec.Cmd
@@ -43,6 +37,22 @@ type Server struct {
lastErrLock sync.Mutex lastErrLock sync.Mutex
} }
// completionRequest is sent to the subprocess
type completionRequest struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Done bool `json:"done"`
}
// NewServer spawns a new image generation subprocess and waits until it's ready. // NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) { func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start // Validate platform support before attempting to start
@@ -129,6 +139,7 @@ func NewServer(modelName string) (*Server, error) {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
slog.Warn("image-runner", "msg", line) slog.Warn("image-runner", "msg", line)
// Capture last error line for better error reporting
s.lastErrLock.Lock() s.lastErrLock.Lock()
s.lastErr = line s.lastErr = line
s.lastErrLock.Unlock() s.lastErrLock.Unlock()
@@ -160,6 +171,7 @@ func (s *Server) ModelPath() string {
return s.modelName return s.modelName
} }
// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) { func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil return nil, nil
} }
@@ -192,16 +204,20 @@ func (s *Server) waitUntilRunning() error {
for { for {
select { select {
case err := <-s.done: case err := <-s.done:
// Include recent stderr lines for better error context // Include last stderr line for better error context
errMsg := s.getLastErr() s.lastErrLock.Lock()
if errMsg != "" { lastErr := s.lastErr
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err) s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
} }
return fmt.Errorf("image runner exited unexpectedly: %w", err) return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout: case <-timeout:
errMsg := s.getLastErr() s.lastErrLock.Lock()
if errMsg != "" { lastErr := s.lastErr
return fmt.Errorf("timeout waiting for image runner: %s", errMsg) s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
} }
return errors.New("timeout waiting for image runner to start") return errors.New("timeout waiting for image runner to start")
case <-ticker.C: case <-ticker.C:
@@ -213,39 +229,44 @@ func (s *Server) waitUntilRunning() error {
} }
} }
// getLastErr returns the last stderr line. // WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
func (s *Server) getLastErr() string { func (s *Server) WaitUntilRunning(ctx context.Context) error {
s.lastErrLock.Lock() return nil
defer s.lastErrLock.Unlock()
return s.lastErr
} }
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil } // Completion generates an image from the prompt via the subprocess.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed // Build request
if seed == 0 { creq := completionRequest{
seed = time.Now().UnixNano()
}
// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt, Prompt: req.Prompt,
Width: req.Width, Width: 1024,
Height: req.Height, Height: 1024,
Seed: seed, Steps: 9,
Seed: time.Now().UnixNano(),
} }
if req.Options != nil {
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
creq.Width = int32(req.Options.NumCtx)
}
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
creq.Height = int32(req.Options.NumGPU)
}
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
creq.Steps = req.Options.NumPredict
}
if req.Options.Seed > 0 {
creq.Seed = int64(req.Options.Seed)
}
}
// Encode request body
body, err := json.Marshal(creq) body, err := json.Marshal(creq)
if err != nil { if err != nil {
return err return err
} }
// Send request to subprocess
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port) url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil { if err != nil {
@@ -260,40 +281,30 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("request failed: %d", resp.StatusCode) return fmt.Errorf("completion request failed: %d", resp.StatusCode)
} }
// Stream responses - use large buffer for base64 image data
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() { for scanner.Scan() {
// Parse subprocess response (has singular "image" field) var cresp completionResponse
var raw struct { if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
continue continue
} }
// Convert to llm.CompletionResponse content := cresp.Content
cresp := llm.CompletionResponse{ // If this is the final response with an image, encode it in the content
Content: raw.Content, if cresp.Done && cresp.Image != "" {
Done: raw.Done, content = "IMAGE_BASE64:" + cresp.Image
Step: raw.Step,
Total: raw.Total,
}
if raw.Image != "" {
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
cresp.Image = data
}
} }
fn(cresp) fn(llm.CompletionResponse{
Content: content,
Done: cresp.Done,
})
if cresp.Done { if cresp.Done {
return nil break
} }
} }
@@ -335,18 +346,22 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize return s.vramSize
} }
// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) { func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("not supported") return nil, 0, errors.New("embedding not supported for image generation models")
} }
// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("not supported") return nil, errors.New("tokenize not supported for image generation models")
} }
// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) { func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("not supported") return "", errors.New("detokenize not supported for image generation models")
} }
// Pid returns the subprocess PID.
func (s *Server) Pid() int { func (s *Server) Pid() int {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -356,9 +371,17 @@ func (s *Server) Pid() int {
return -1 return -1
} }
func (s *Server) GetPort() int { return s.port } // GetPort returns the subprocess port.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil } func (s *Server) GetPort() int {
return s.port
}
// GetDeviceInfos returns nil since we don't track GPU info.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// HasExited returns true if the subprocess has exited.
func (s *Server) HasExited() bool { func (s *Server) HasExited() bool {
select { select {
case <-s.done: case <-s.done:

77
x/kvcache/cache.go Normal file
View File

@@ -0,0 +1,77 @@
package kvcache
import (
"errors"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}

797
x/kvcache/causal.go Normal file
View File

@@ -0,0 +1,797 @@
package kvcache
// import (
// "errors"
// "fmt"
// "log/slog"
// "math"
// "slices"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
// // Causal cache stores K and V tensors according to their position in the
// // sequence. Returns the history and a mask for attending to past tokens
// //
// // The tensors are of shape embed dim, kv heads, batch size
// // The mask is of shape history size, batch size
// type Causal struct {
// DType ml.DType
// // swaWindowSize is the number of tokens that will be included in the mask
// // during attention operations. swaMemorySize is the number of tokens that
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
// // for unlimited or if sliding window attention is not being used.
// swaWindowSize int32
// swaMemorySize int32
// chunkSize int32
// opts CausalOptions
// // maxBatch is the largest batch that we might receive
// maxBatch int
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // size of the current batch
// curBatchSize int
// // locations for data storage for this batch
// curLoc ml.Tensor
// // mask of the cache as used by this batch
// curMask ml.Tensor
// // the active layer for Get and Put
// curLayer int
// // locations in the cache that are needed for this batch
// curCellRange cellRange
// // curSequences is the sequences corresponding to this pass's entries in the cache
// curSequences []int
// // curPositions is the positions corresponding to this pass's entries in the cache
// curPositions []int32
// // ** cache metadata **
// // for each possible location in the cache, stores the position and set of sequences
// // that reference the data there
// cells []cacheCell
// // maps from sequence to the range of locations where it is stored in the cache
// cellRanges map[int]cellRange
// // ** cache data storage **
// shiftFn shiftFn
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// kHeadDims, vHeadDims, numKVHeads map[int]int
// }
// type cacheCell struct {
// pos int32
// sequences []int
// }
// type cellRange struct {
// min int
// max int
// }
// func NewCausalCache(shift shiftFn) *Causal {
// return &Causal{
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
// return &Causal{
// swaWindowSize: windowSize,
// swaMemorySize: memorySize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
// return &Causal{
// chunkSize: chunkSize,
// shiftFn: shift,
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// kHeadDims: make(map[int]int),
// vHeadDims: make(map[int]int),
// numKVHeads: make(map[int]int),
// }
// }
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if c.config.CachePadding == 0 {
// c.config.CachePadding = 1
// }
// if c.config.MaskBatchPadding == 0 {
// c.config.MaskBatchPadding = 1
// }
// // TODO what types do we handle here?
// // if c.config.MaskDType == ml.DTypeOther {
// // c.config.MaskDType = ml.DTypeFloat32
// // }
// if c.swaWindowSize == 0 {
// c.swaWindowSize = math.MaxInt32
// }
// if c.swaMemorySize == 0 {
// c.swaMemorySize = c.swaWindowSize
// }
// // We will allocate space in the cache for the stop token, which won't be part of a follow-on
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
// // causing a cache break. As an optimization, only do this when we have parallel sequences
// // because the extra token will live in the batch buffer and won't get overwritten if we
// // only have a single sequence.
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
// }
// if int(c.swaMemorySize) >= capacity {
// c.swaMemorySize = math.MaxInt32
// }
// if c.swaMemorySize < c.swaWindowSize {
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
// }
// var cacheSize int
// if c.swaMemorySize == math.MaxInt32 {
// cacheSize = maxSequences * capacity
// } else {
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
// }
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
// c.cells = make([]cacheCell, cacheSize)
// c.DType = dtype
// c.cellRanges = make(map[int]cellRange)
// c.backend = backend
// c.maxBatch = maxBatch
// }
// func (c *Causal) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *Causal) Close() {
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
// // panic("XXX Causal.StartForward")
// c.curBatchSize = len(batch.Positions)
// c.curSequences = batch.Sequences
// c.curPositions = batch.Positions
// c.opts.Except = nil
// var locs []int32
// if !reserve {
// c.updateSlidingWindow()
// var err error
// locs, err = c.findLocs()
// if err != nil {
// return err
// }
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
// for i, pos := range batch.Positions {
// seq := batch.Sequences[i]
// loc := int(locs[i])
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// seqRange = newRange()
// }
// seqRange.min = min(seqRange.min, loc)
// c.curCellRange.min = min(c.curCellRange.min, loc)
// seqRange.max = max(seqRange.max, loc)
// c.curCellRange.max = max(c.curCellRange.max, loc)
// c.cellRanges[seq] = seqRange
// }
// } else {
// // If we are reserving memory, don't update any of the cache metadata but set the size
// // to the worst case.
// locs = make([]int32, c.curBatchSize)
// for i := range locs {
// locs[i] = int32(i)
// }
// c.curCellRange.min = 0
// c.curCellRange.max = len(c.cells) - 1
// }
// // XXX Building up the locs for what's already processed (if any)
// dummyLocs := []int{}
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// } else {
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
// dummyLocs = append(dummyLocs, i)
// }
// }
// }
// }
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
// slog.Info("XXX Causal.StartForward", "locs", locs)
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
// c.curMask = c.buildMask(ctx)
// return nil
// }
// func newRange() cellRange {
// return cellRange{
// min: math.MaxInt,
// max: 0,
// }
// }
// // Returns a slice of locations where each token in the batch should be stored
// func (c *Causal) findLocs() ([]int32, error) {
// loc := make([]int32, 0, c.curBatchSize)
// for i := range c.cells {
// if len(c.cells[i].sequences) == 0 {
// loc = append(loc, int32(i))
// if len(loc) >= c.curBatchSize {
// return loc, nil
// }
// }
// }
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
// }
// func (c *Causal) updateSlidingWindow() {
// c.curCellRange = newRange()
// if c.swaMemorySize == math.MaxInt32 {
// for _, seq := range c.curSequences {
// if seqRange, ok := c.cellRanges[seq]; ok {
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
// }
// }
// return
// }
// type lowestPosition struct {
// pos int32
// curBatch bool
// }
// // create a map of unique sequences to the lowest position in that sequence
// lowestPos := make(map[int]lowestPosition)
// for i := range c.curPositions {
// seq := c.curSequences[i]
// lowest, ok := lowestPos[seq]
// if !ok {
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
// } else if c.curPositions[i] < lowest.pos {
// lowest.pos = c.curPositions[i]
// }
// lowestPos[seq] = lowest
// }
// // for any sequences that are not part of this batch, clean up any tokens
// // that are no longer needed after the processing of the previous
// // batch
// for seq, seqRange := range c.cellRanges {
// if _, ok := lowestPos[seq]; !ok {
// var last int32
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// last = max(last, c.cells[i].pos)
// }
// }
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
// }
// }
// // delete any entries that are beyond the window of the oldest position in the sequence
// for seq, lowest := range lowestPos {
// oldRange, ok := c.cellRanges[seq]
// if !ok {
// continue
// }
// newRange := newRange()
// for i := oldRange.min; i <= oldRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// newRange.min = min(newRange.min, i)
// newRange.max = max(newRange.max, i)
// }
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
// c.curCellRange.min = min(c.curCellRange.min, i)
// c.curCellRange.max = max(c.curCellRange.max, i)
// }
// }
// }
// c.cellRanges[seq] = newRange
// }
// }
// func roundDown(length, pad int) int {
// return (length / pad) * pad
// }
// func roundUp(length, pad int) int {
// return ((length + pad - 1) / pad) * pad
// }
// // Builds a mask of history x batch indicating, for each token in the batch, whether the
// // corresponding token in the history should be attended to. This is based on both the
// // sequence and causality (the position in the history is not ahead of the token in the batch).
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// // Align and pad the two dimensions as required by the backend
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
// length := c.curCellRange.max - c.curCellRange.min + 1
// mask := make([]float32, batchSize*length)
// for i := range c.curBatchSize {
// enabled := !slices.Contains(c.opts.Except, i)
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// }
// }
// }
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
// // has already been masked out because the sequence doesn't match.
// for i := c.curBatchSize * length; i < len(mask); i++ {
// mask[i] = float32(math.Inf(-1))
// }
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
// // if c.config.MaskDType != ml.DTypeFloat32 {
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
// // }
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
// return maskTensor
// }
// func (c *Causal) SetLayer(layer int) {
// c.curLayer = layer
// }
// type CausalOptions struct {
// // Except lists the indices in the current batch for which the causal mask is not applied
// Except []int
// }
// // SetCausal disables causal mask generation for a particular range of indices in
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
// if !slices.Equal(c.opts.Except, opts.Except) {
// c.opts = opts
// if ctx != nil {
// c.curMask = c.buildMask(ctx)
// }
// }
// }
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// key := c.keys[c.curLayer]
// value := c.values[c.curLayer]
// kHeadDim := c.kHeadDims[c.curLayer]
// vHeadDim := c.vHeadDims[c.curLayer]
// numKVHeads := c.numKVHeads[c.curLayer]
// // rowSize := numKVHeads * c.curBatchSize
// // cachedSize := c.curMask.Dim(1)
// cachedSize := c.curLoc.Dim(0)
// // kCellSize := kHeadDim * numKVHeads
// // vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Get full cache", "key", key)
// slog.Info("XXX Causal.Get full cache", "value", value)
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
// // panic("XXX")
// // fmt.Fprintln(os.Stderr, key.ToString())
// // panic("full cache value")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not converted
// // vHeadDim := value.Dim(1)
// // elemSize := value.Stride(2)
// // value = value.AsStrided(ctx,
// // []int{numKVHeads, vHeadDim, cachedSize},
// // []int{value.Stride(0), value.Stride(1)},
// // elemSize*c.curCellRange.min,
// // )
// // } else {
// // vHeadDim := c.vHeadDims[c.curLayer]
// // rowSize := value.Stride(2)
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
// // panic("XXX")
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
// // panic("XXX")
// // }
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
// // // the 1 becomes trailing and messes up later operations
// // // This isn't the right solution, but works around it...
// // if c.curMask.Dim(1) == 1 {
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
// // }
// // fmt.Fprintln(os.Stderr, key.ToString())
// // fmt.Fprintln(os.Stderr, value.ToString())
// // panic("XXX")
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
// return key, value, c.curMask
// }
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
// kHeadDim := key.Dim(3)
// vHeadDim := value.Dim(3)
// numKVHeads := key.Dim(1)
// batchSize := key.Dim(2)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
// // panic("XXX")
// if c.curBatchSize != batchSize {
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
// }
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
// if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
// c.kHeadDims[c.curLayer] = kHeadDim
// c.vHeadDims[c.curLayer] = vHeadDim
// c.numKVHeads[c.curLayer] = numKVHeads
// }
// if _, ok := c.values[c.curLayer]; !ok {
// // if c.config.PermutedV {
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
// // } else {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
// // }
// }
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
// // panic("XXX")
// // curLoc := 0 // TODO c.curLoc is now a tensor
// // kSize := numKVHeads * kHeadDim
// // vSize := numKVHeads * vHeadDim
// // start := []int{int(curLoc), 0}
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
// // strides := []int{1, 1}
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
// // panic("input value")
// // fmt.Fprintln(os.Stderr, t.ToString())
// // panic("XXX")
// // if c.config.PermutedV {
// // panic("permuted")
// // // TODO not adjusted
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
// // value = value.Transpose(ctx, 2, 0, 1, 3)
// // valueCache := c.values[c.curLayer]
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
// // } else {
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
// // }
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
// // panic("XXX")
// }
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
// seqRange := newRange()
// for i := range c.cells {
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
// if slices.Contains(c.cells[i].sequences, dstSeq) {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
// }
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// c.cellRanges[dstSeq] = seqRange
// }
// func (c *Causal) CanResume(seq int, pos int32) bool {
// if c.swaMemorySize == math.MaxInt32 {
// return true
// }
// seqRange, ok := c.cellRanges[seq]
// if !ok {
// return false
// }
// // for sliding window, check that the window of the new sequence is contained in
// // the window of what we are storing
// var first int32 = math.MaxInt32
// var last int32 = -1
// for i := seqRange.min; i <= seqRange.max; i++ {
// if slices.Contains(c.cells[i].sequences, seq) {
// first = min(first, c.cells[i].pos)
// last = max(last, c.cells[i].pos)
// }
// }
// if last == -1 {
// return false
// }
// posWindowStart := max(0, pos-c.swaWindowSize)
// return posWindowStart >= first && pos <= last+1
// }
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
// if c.shiftFn == nil {
// return ErrNotSupported
// }
// seqRange := c.cellRanges[seq]
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
// size := min(seqRange.max-start+1, c.maxBatch)
// offsets := make([]int32, size)
// var batchFirst, batchLast int
// batchFirst = -1
// for i := range offsets {
// cell := c.cells[start+i]
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
// offsets[i] = offset
// if batchFirst < 0 {
// batchFirst = i
// }
// batchLast = i
// }
// }
// if batchFirst < 0 {
// continue
// }
// offsets = offsets[batchFirst : batchLast+1]
// slog.Info("XXX Causal.shift creating new temporary context")
// ctx := c.backend.NewContext()
// kShift := ctx.Input().FromInts(offsets, len(offsets))
// for i, key := range c.keys {
// if key == nil {
// continue
// }
// kHeadDim := key.Dim(2)
// numKVHeads := key.Dim(1)
// rowSize := key.Stride(0)
// key = key.AsStrided(ctx,
// []int{len(offsets), numKVHeads, kHeadDim},
// []int{key.Stride(0), key.Stride(1)},
// rowSize*(start+batchFirst),
// )
// roped, err := c.shiftFn(ctx, i, key, kShift)
// if err != nil {
// ctx.Close()
// return err
// }
// ctx.Forward(roped.Copy(ctx, key))
// }
// ctx.Compute()
// ctx.Close()
// }
// return nil
// }
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
// // should return an error, which will trigger the runner to evaluate the full history and
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
// // results in use after free, so we don't do it for now.
// var offset int32
// if endIndex != math.MaxInt32 {
// offset = beginIndex - endIndex
// }
// seqRange := newRange()
// for i := range c.cells {
// if slices.Contains(c.cells[i].sequences, seq) {
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// } else {
// if c.cells[i].pos >= endIndex {
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// return errors.New("shifting cells shared by multiple sequences not supported")
// }
// c.cells[i].pos += offset
// }
// if i < seqRange.min {
// seqRange.min = i
// }
// if i > seqRange.max {
// seqRange.max = i
// }
// }
// }
// }
// if seqRange == newRange() {
// delete(c.cellRanges, seq)
// return nil
// }
// c.cellRanges[seq] = seqRange
// if endIndex != math.MaxInt32 {
// err := c.shift(seq, endIndex+offset, offset)
// if err != nil {
// return err
// }
// }
// return nil
// }
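
To make the sliding-window sizing in the (currently commented-out) `Init` above concrete, here is the same arithmetic as a standalone sketch with illustrative numbers; `causalCacheCells` is not a real function in this package.

```go
// causalCacheCells mirrors the SWA sizing in Init: each sequence retains
// swaMemorySize tokens, plus room for one in-flight batch, rounded up to the
// backend's CachePadding (assumed >= 1). For example, 4 sequences, a
// 1024-token window and a 512-token max batch with padding 32 gives
// (4*1024 + 512) = 4608 cells (already a multiple of 32).
func causalCacheCells(maxSequences, swaMemorySize, maxBatch, cachePadding int) int {
	size := maxSequences*swaMemorySize + maxBatch
	return ((size + cachePadding - 1) / cachePadding) * cachePadding
}
```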

973
x/kvcache/causal_test.go Normal file

@@ -0,0 +1,973 @@
package kvcache
// import (
// "fmt"
// "math"
// "slices"
// "testing"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type testCase struct {
// name string
// in []float32
// inShape []int
// seqs []int
// pos []int32
// expected []float32
// expectedShape []int
// expectedMask []float32
// }
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
// t.Helper()
// for _, permuted := range []bool{false, true} {
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
// fn(t, &testBackend{permutedV: permuted})
// })
// }
// }
// func TestStore(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// inShape: []int{2, 3, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// expectedShape: []int{2, 3, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{115, 215, 125, 225, 135, 235},
// inShape: []int{2, 3, 1},
// seqs: []int{0},
// pos: []int32{4},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
// expectedShape: []int{2, 3, 5},
// expectedMask: []float32{0, 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWA(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 6, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWASeparateBatches(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "First seq 0",
// in: []float32{1, 2},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{0, 1},
// expected: []float32{1, 2},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 0",
// in: []float32{3, 4},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{2, 3},
// expected: []float32{2, 3, 4},
// expectedShape: []int{1, 1, 3},
// expectedMask: []float32{
// 0, 0, x,
// x, 0, 0,
// },
// },
// {
// name: "First seq 1",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{0, 1},
// expected: []float32{5, 6},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 1",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{2, 3},
// expected: []float32{6, 3, 4, 7, 8},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// x, x, x, 0, 0,
// },
// },
// {
// name: "Third seq 0",
// in: []float32{9, 10},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{9, 10, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWAMemCache(1, 3, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 2, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestChunkedAttention(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewChunkedAttentionCache(2, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// testCache(
// t, backend, cache,
// []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6, 7},
// inShape: []int{1, 1, 3},
// seqs: []int{0, 0, 0},
// pos: []int32{4, 5, 6},
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
// expectedShape: []int{1, 1, 7},
// expectedMask: []float32{
// x, x, x, x, 0, x, x,
// x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, 0,
// },
// },
// {
// name: "ThirdBatch",
// in: []float32{8, 9},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{7, 8},
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
// expectedShape: []int{1, 1, 9},
// expectedMask: []float32{
// x, x, x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, x, x, 0,
// },
// },
// },
// )
// })
// }
// func TestSequences(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{2, 2},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestRemove(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// return key.Add(ctx, shift), nil
// })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err := cache.Remove(0, 1, math.MaxInt32)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveEnd",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{1, 2},
// expected: []float32{1, 5, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, 0, x, x, x,
// x, x, 0, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err = cache.Remove(0, 0, 1)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveMiddle",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{1, 2},
// expected: []float32{7, 4, 3, 4, 6, 8},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{
// 0, 0, x, x, x, x,
// 0, 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestCopy(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// cache.CopyPrefix(0, 1, 2)
// tests = []testCase{
// {
// name: "Copy",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{3, 4},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
// if err != nil {
// panic(err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats(test.in, test.inShape...)
// cache.Put(context, tensor, tensor)
// out, _, mask := cache.Get(context)
// context.Forward(out, mask).Compute(out, mask)
// if !slices.Equal(out.Floats(), test.expected) {
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
// }
// if !slices.Equal(out.Shape(), test.expectedShape) {
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
// }
// if !slices.Equal(mask.Floats(), test.expectedMask) {
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
// }
// })
// }
// }
// func TestCanResume(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// cache := NewSWACache(windowSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4},
// Sequences: []int{0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
// cache.Put(context, tensor, tensor)
// // with window size 4, nothing has slid out of the window yet
// if !cache.CanResume(0, 0) {
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
// }
// if !cache.CanResume(0, 1) {
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
// }
// if !cache.CanResume(0, 2) {
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
// }
// if !cache.CanResume(0, 3) {
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
// }
// if !cache.CanResume(0, 4) {
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
// }
// // shift window by adding position 5
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{5},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
// }
// })
// }
// func TestCanResumeSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// memSize := int32(5)
// cache := NewSWAMemCache(windowSize, memSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
// cache.Put(context, tensor, tensor)
// // shift window by adding position 7
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{7},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 6) {
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
// }
// if !cache.CanResume(0, 7) {
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
// }
// })
// }
// type testBackend struct {
// ml.Backend
// permutedV bool
// }
// func (b *testBackend) NewContext() ml.Context {
// return &testContext{}
// }
// func (b *testBackend) NewContextSize(int) ml.Context {
// return &testContext{}
// }
// func (b *testBackend) CacheConfig() ml.CacheConfig {
// return ml.CacheConfig{PermutedV: b.permutedV}
// }
// type testContext struct {
// ml.Context
// }
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
// total := 0
// if len(shape) > 0 {
// total = 1
// for _, s := range shape {
// total *= s
// }
// }
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
// }
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
// return c.Empty(dtype, shape...)
// }
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
// copy(t.data, s)
// return t
// }
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
// f := make([]float32, len(s))
// for i := range f {
// f[i] = float32(s[i])
// }
// out := c.FromFloats(f, shape...)
// out.(*testTensor).dtype = ml.DTypeI32
// return out
// }
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
// s := make([]float32, 0, int((stop-start)/step))
// for i := start; i < stop; i += step {
// s = append(s, i)
// }
// out := c.FromFloats(s, len(s))
// out.(*testTensor).dtype = dtype
// return out
// }
// func (c *testContext) Input() ml.Context { return c }
// func (c *testContext) Layer(int) ml.Context { return c }
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
// func (c *testContext) Compute(...ml.Tensor) {}
// func (c *testContext) Reserve() {}
// func (c *testContext) MaxGraphNodes() int {
// return 10
// }
// func (c *testContext) Close() {}
// type testTensor struct {
// ml.Tensor
// dtype ml.DType
// elementSize int
// data []float32
// shape []int
// }
// func (t *testTensor) Dim(n int) int {
// return t.shape[n]
// }
// func (t *testTensor) Stride(n int) int {
// stride := t.elementSize
// for i := range n {
// stride *= t.shape[i]
// }
// return stride
// }
// func (t *testTensor) Shape() []int {
// return t.shape
// }
// func (t *testTensor) DType() ml.DType {
// return t.dtype
// }
// func (t *testTensor) Floats() []float32 {
// out := make([]float32, len(t.data))
// copy(out, t.data)
// return out
// }
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = -t.data[i]
// }
// return out
// }
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
// }
// return out
// }
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: t.data,
// shape: shape,
// }
// }
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
// offset /= t.elementSize
// var s []int
// switch len(shape) {
// case 1:
// s = []int{shape[0]}
// case 3:
// s = []int{shape[0], shape[2]}
// case 5:
// s = []int{shape[0], shape[2], shape[4]}
// default:
// panic("unsupported number of dimensions")
// }
// context := &testContext{}
// view := context.Empty(t.dtype, s...).(*testTensor)
// view.data = t.data[offset : offset+len(view.data)]
// return view
// }
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
// if len(t.shape) > 4 || len(order) > 4 {
// panic("permute only supports up to 4 dimensions")
// }
// if len(order) != len(t.shape) && len(order) != 4 {
// panic("invalid number of dimensions for permute")
// }
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
// orderFull := append(make([]int, 0, 4), order...)
// for len(orderFull) < 4 {
// orderFull = append(orderFull, len(orderFull))
// }
// seen := [4]bool{}
// shape4 := [4]int{1, 1, 1, 1}
// for i := 0; i < len(t.shape) && i < 4; i++ {
// shape4[i] = t.shape[i]
// }
// newShape4 := [4]int{1, 1, 1, 1}
// for axis := range 4 {
// dst := orderFull[axis]
// if dst < 0 || dst >= 4 {
// panic("invalid axis for permute")
// }
// if seen[dst] {
// panic("duplicate axis for permute")
// }
// seen[dst] = true
// newShape4[dst] = shape4[axis]
// }
// total := len(t.data)
// newData := make([]float32, total)
// if total > 0 {
// oldDims := shape4
// newDims := newShape4
// oldStride := [4]int{1, 1, 1, 1}
// newStride := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
// newStride[i] = newStride[i-1] * newDims[i-1]
// }
// var coords [4]int
// var newCoords [4]int
// for idx := range total {
// remainder := idx
// for axis := range 4 {
// dim := oldDims[axis]
// if dim == 0 {
// coords[axis] = 0
// continue
// }
// coords[axis] = remainder % dim
// remainder /= dim
// }
// for axis := range 4 {
// newCoords[orderFull[axis]] = coords[axis]
// }
// newIndex := 0
// for axis := range 4 {
// if newDims[axis] == 0 {
// continue
// }
// newIndex += newCoords[axis] * newStride[axis]
// }
// newData[newIndex] = t.data[idx]
// }
// }
// numDims := 4
// for numDims > 1 && newShape4[numDims-1] <= 1 {
// numDims--
// }
// newShape := make([]int, numDims)
// copy(newShape, newShape4[:numDims])
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: newData,
// shape: newShape,
// }
// }
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
// dst := t
// srcTensor := src.(*testTensor)
// idxTensor := idxs.(*testTensor)
// shapeTo4D := func(shape []int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 0; i < len(shape) && i < 4; i++ {
// out[i] = shape[i]
// }
// return out
// }
// computeStrides := func(shape [4]int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// out[i] = out[i-1] * shape[i-1]
// }
// return out
// }
// dstShape4D := shapeTo4D(dst.shape)
// srcShape4D := shapeTo4D(srcTensor.shape)
// idxShape4D := shapeTo4D(idxTensor.shape)
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
// panic("SetRows requires matching tensor shapes")
// }
// if srcShape4D[1] != idxShape4D[0] {
// panic("SetRows rows/index mismatch")
// }
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
// panic("SetRows cannot broadcast indices")
// }
// if idxShape4D[3] != 1 {
// panic("SetRows expects 1D or 2D index tensors")
// }
// dstStride := computeStrides(dstShape4D)
// srcStride := computeStrides(srcShape4D)
// idxStride := computeStrides(idxShape4D)
// numColumns := srcShape4D[0]
// numRows := srcShape4D[1]
// for dim3Index := range dstShape4D[3] {
// for dim2Index := range dstShape4D[2] {
// idxDim2 := 0
// idxDim3 := 0
// if idxShape4D[1] > 0 {
// idxDim2 = dim2Index % idxShape4D[1]
// }
// if idxShape4D[2] > 0 {
// idxDim3 = dim3Index % idxShape4D[2]
// }
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
// for row := range numRows {
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
// if idx < 0 || idx >= dstShape4D[1] {
// panic("SetRows index out of range")
// }
// srcOffset := srcBase + row*srcStride[1]
// dstOffset := dstBase + idx*dstStride[1]
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
// }
// }
// }
// return dst
// }
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// copy(t2.(*testTensor).data, t.data)
// return nil
// }

156
x/kvcache/encoder.go Normal file

@@ -0,0 +1,156 @@
package kvcache
// import (
// "fmt"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// // config controls mostly backend-specific optimizations
// config *ml.CacheConfig
// // ** current forward pass **
// // the active layer for Get and Put
// curLayer int
// // if something is stored during this pass, this
// // will be the position (but there is no guarantee
// // anything will be stored)
// curPos int32
// // curReserve indicates that this forward pass is only for
// // memory reservation and we should not update our metadata
// // based on it.
// curReserve bool
// // ** cache metadata **
// // was something stored in the cache?
// encoderCached bool
// // position of the cached data
// encoderPos int32
// // ** cache data storage **
// backend ml.Backend
// ctxs map[int]ml.Context
// keys, values map[int]ml.Tensor
// }
// func NewEncoderCache() *EncoderCache {
// return &EncoderCache{
// ctxs: make(map[int]ml.Context),
// keys: make(map[int]ml.Tensor),
// values: make(map[int]ml.Tensor),
// }
// }
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// if c.config == nil {
// var config ml.CacheConfig
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
// config = cc.CacheConfig()
// }
// c.config = &config
// }
// if maxSequences > 1 {
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// }
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// }
// c.backend = backend
// }
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// if c.config != nil {
// panic("config cannot be changed after being previously set, either by the model or backend")
// }
// c.config = &config
// }
// func (c *EncoderCache) Close() {
// for _, ctx := range c.ctxs {
// ctx.Close()
// }
// }
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// // We work with the most recent image
// if len(batch.Multimodal) > 0 {
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// }
// c.curReserve = reserve
// return nil
// }
// func (c *EncoderCache) SetLayer(layer int) {
// c.curLayer = layer
// }
// func (c *EncoderCache) EncoderCached() bool {
// return c.encoderCached
// }
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.keys[c.curLayer], c.values[c.curLayer], nil
// }
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// if !c.curReserve {
// c.encoderPos = c.curPos
// c.encoderCached = true
// }
// if c.config.PermutedV {
// value = value.Transpose(ctx, 1, 2, 0, 3)
// }
// if _, ok := c.ctxs[c.curLayer]; !ok {
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// }
// if _, ok := c.keys[c.curLayer]; !ok {
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// }
// if _, ok := c.values[c.curLayer]; !ok {
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// }
// ctx.Forward(
// key.Copy(ctx, c.keys[c.curLayer]),
// value.Copy(ctx, c.values[c.curLayer]),
// )
// }
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// panic("encoder cache does not support multiple sequences")
// }
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// return true
// }
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// c.encoderCached = false
// }
// return nil
// }
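
If the `EncoderCache` above is re-enabled, the intended flow for a cross-attention layer would look roughly like this hypothetical helper (same imports as the earlier sketch; `crossAttendKV` is not part of the package):

```go
// crossAttendKV stores the image K/V once and then replays them, unmasked, on
// every later decode step. Gating Put on EncoderCached is done here purely for
// illustration.
func crossAttendKV(ctx ml.Context, cache *kvcache.EncoderCache, k, v ml.Tensor) (key, value ml.Tensor) {
	if !cache.EncoderCached() {
		cache.Put(ctx, k, v) // keep the projections for the most recent image
	}
	key, value, _ = cache.Get(ctx) // the encoder cache never returns a mask
	return key, value
}
```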

144
x/kvcache/mlx.go Normal file

@@ -0,0 +1,144 @@
//go:build mlx
package kvcache
import (
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)
// MLXCausal stores K and V tensors according to their position in the
// sequence. Get returns the accumulated history; no attention mask is currently returned.
type MLXCausal struct {
DType ml.DType
// locations for data storage for this batch
curLocPut ml.Tensor
// locations for data storage for this batch
curLocGet ml.Tensor
// the active layer for Get and Put
curLayer int
capacity int
offset int
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
// TODO is this needed per layer, or will it always be consistent?
kHeadDims, vHeadDims, numKVHeads map[int]int
}
func NewMLXCausalCache() *MLXCausal {
return &MLXCausal{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
kHeadDims: make(map[int]int),
vHeadDims: make(map[int]int),
numKVHeads: make(map[int]int),
}
}
func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
c.DType = dtype
c.capacity = capacity
c.backend = backend
}
func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
func (c *MLXCausal) SetLayer(layer int) {
c.curLayer = layer
}
func (c *MLXCausal) Close() {
// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
locsPut := make([]int32, len(batch.Positions))
for i := range batch.Positions {
locsPut[i] = int32(c.offset + i)
}
c.offset += len(batch.Positions)
locsGet := make([]int32, c.offset)
for i := range c.offset {
locsGet[i] = int32(i)
}
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
return nil
}
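// Illustrative trace of the bookkeeping above (assumed values, not from a real run):
//   first batch of 4 tokens: curLocPut = [0 1 2 3], offset -> 4, curLocGet = [0 1 2 3]
//   next batch of 1 token:   curLocPut = [4],       offset -> 5, curLocGet = [0 1 2 3 4]
// Put scatters the new keys/values into curLocPut; Get gathers the whole prefix via curLocGet.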
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(3)
vHeadDim := value.Dim(3)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
kCellSize := kHeadDim * numKVHeads
vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
if _, ok := c.ctxs[c.curLayer]; !ok {
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
c.kHeadDims[c.curLayer] = kHeadDim
c.vHeadDims[c.curLayer] = vHeadDim
c.numKVHeads[c.curLayer] = numKVHeads
}
key = key.Reshape(ctx, batchSize, 1, kCellSize)
// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
// slog.Info("XXX MLXCausal.Put ", "key", key)
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
value = value.Reshape(ctx, batchSize, 1, vCellSize)
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}
func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := c.kHeadDims[c.curLayer]
vHeadDim := c.vHeadDims[c.curLayer]
numKVHeads := c.numKVHeads[c.curLayer]
// rowSize := numKVHeads * c.curBatchSize
// cachedSize := c.curMask.Dim(1)
cachedSize := c.curLocGet.Dim(0)
// kCellSize := kHeadDim * numKVHeads
// vCellSize := vHeadDim * numKVHeads
// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
return key, value, nil
}
func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("not implemented")
}
func (c *MLXCausal) CanResume(seq int, pos int32) bool {
panic("not implemented")
}
func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
panic("not implemented")
}
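
A hypothetical single-layer driver, assuming the `mlx` build tag and the same imports as the earlier sketch, illustrating the append-only contract `MLXCausal` currently supports (a single sequence, batches presented strictly in position order):

```go
// decodeStep appends one batch to the cache and returns the full prefix.
func decodeStep(ctx ml.Context, cache *kvcache.MLXCausal, batch input.Batch, k, v ml.Tensor) (key, value ml.Tensor, err error) {
	if err = cache.StartForward(ctx, batch, false); err != nil {
		return nil, nil, err
	}
	cache.SetLayer(0)
	cache.Put(ctx, k, v)           // scatter this batch into the next free cells
	key, value, _ = cache.Get(ctx) // gather everything stored so far; mask is nil
	return key, value, nil
}
```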

110
x/kvcache/wrapper.go Normal file

@@ -0,0 +1,110 @@
package kvcache
// import (
// "math"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// // Wrapper cache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
// // caches we are wrapping
// caches []Cache
// // cache to be used for this layer
// curType int
// }
// func NewWrapperCache(caches ...Cache) *WrapperCache {
// return &WrapperCache{
// caches: caches,
// }
// }
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// for _, cache := range c.caches {
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
// }
// }
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
// for _, cache := range c.caches {
// cache.SetConfig(config)
// }
// }
// func (c *WrapperCache) Close() {
// for _, cache := range c.caches {
// cache.Close()
// }
// }
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// for i, cache := range c.caches {
// err := cache.StartForward(ctx, batch, reserve)
// if err != nil {
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
// for j := i - 1; j >= 0; j-- {
// for k := range batch.Positions {
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
// }
// }
// return err
// }
// }
// c.curType = 0
// return nil
// }
// func (c *WrapperCache) SetLayer(layer int) {
// for _, cache := range c.caches {
// cache.SetLayer(layer)
// }
// }
// func (c *WrapperCache) SetLayerType(layerType int) {
// c.curType = layerType
// }
// func (c *WrapperCache) UnderlyingCache() Cache {
// return c.caches[c.curType]
// }
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// return c.caches[c.curType].Get(ctx)
// }
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
// c.caches[c.curType].Put(ctx, key, value)
// }
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// for _, cache := range c.caches {
// cache.CopyPrefix(srcSeq, dstSeq, len)
// }
// }
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
// for _, cache := range c.caches {
// if !cache.CanResume(seq, pos) {
// return false
// }
// }
// return true
// }
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// // If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
// for _, cache := range c.caches {
// err := cache.Remove(seq, beginIndex, endIndex)
// if err != nil {
// return err
// }
// }
// return nil
// }

433
x/ml/backend.go Normal file
View File

@@ -0,0 +1,433 @@
package ml
import (
"fmt"
"log/slog"
"os"
"github.com/ollama/ollama/fs"
)
type Backend interface {
// Close frees all memory associated with this backend
// Close()
// Load(ctx context.Context, progress func(float32)) error
// BackendMemory returns the memory allocations that were made for this model
// BackendMemory() BackendMemory
Config() fs.Config
Get(name string) Tensor
NewContext() Context
// NewContextSize(size int) Context
// Enumerate the devices available for inference via this backend
// BackendDevices() []DeviceInfo
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
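//
// Illustrative sketch (values assumed, not a real backend): a backend wanting
// 32-token padded history and bfloat16 masks could implement BackendCacheConfig as
//
//	func (b *Backend) CacheConfig() ml.CacheConfig {
//		return ml.CacheConfig{CachePadding: 32, MaskDType: ml.DTypeBfloat16}
//	}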
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and returns the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// AllocMemory causes the backend to allocate memory for the model. If
// false, this is only being used for discovering the required amount of
// memory and cannot load the model for running.
AllocMemory bool
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
// GPULayers is the set of layers to offload to GPUs
GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
backends[name] = f
}
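// NewBackend looks up a registered backend by name from OLLAMA_BACKEND,
// defaulting to "mlx" when unset. Illustrative call from a caller's
// perspective (sketch; modelPath assumed):
//
//	b, err := ml.NewBackend(modelPath, ml.BackendParams{AllocMemory: true, NumThreads: 8})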
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
be := os.Getenv("OLLAMA_BACKEND")
if be == "" {
be = "mlx"
slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
}
slog.Info("Loading new engine", "backend", be)
if backend, ok := backends[be]; ok {
return backend(modelPath, params)
}
return nil, fmt.Errorf("unsupported backend %q", be)
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
// FromBytes(dtype DType, s []byte, shape ...int) Tensor
FromFloats(s []float32, shape ...int) Tensor
FromInts(s []int32, shape ...int) Tensor
RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
// Arange creates a 1D tensor with values in the half-open interval [start, stop), increasing by step.
Arange(start, stop, step float32, dtype DType) Tensor
Forward(...Tensor) Context
// SetBatchSize provides a hint on the batch size to optimize processing
// Uses heuristics if not set
// SetBatchSize(int)
Compute(...Tensor)
// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available
// for future inference.
// Reserve()
// MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating tensors that are
// inputs to the model (which includes things like output locations)
Input() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
// Load a tensor from "filename" safetensors file, and compare with the input tensor
// Returns error if the shape is inconsistent, or similarity measures are below 99%
CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
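// RoPEOptions collects optional parameters for Tensor.RoPE. Illustrative call
// (sketch; the query tensor, headDim and offset are assumed from the caller):
//
//	q = q.RoPE(ctx, headDim, false, 1.0, offset, WithRoPEBase(1e6))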
type RoPEOptions struct {
Base *float32
Freqs Tensor
}
func WithRoPEBase(base float32) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Base = &base
}
}
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
return func(opts *RoPEOptions) {
opts.Freqs = freqs
}
}
type Tensor interface {
ToString() string
RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
TakeAxes(ctx Context, indices Tensor, axes int) Tensor
// TakeAxes(ctx Context, axes int, indices ...int) Tensor
Dim(n int) int
Stride(n int) int
Shape() []int
DType() DType
// Cast(ctx Context, dtype DType) Tensor
// Bytes() []byte
Floats() []float32
Ints() []int32
// FromBytes([]byte)
// FromFloats([]float32)
// FromInts([]int32)
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
// Mul(ctx Context, t2 Tensor) Tensor
// Div(ctx Context, t2 Tensor) Tensor
Max(ctx Context, axes []int, keepDims bool) Tensor
Min(ctx Context, axes []int, keepDims bool) Tensor
Matmul(ctx Context, a2 Tensor) Tensor
// Mulmat(ctx Context, t2 Tensor) Tensor
// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
// MulmatID(ctx Context, t2, ids Tensor) Tensor
// AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
// SumRows(ctx Context) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
// Sin(ctx Context) Tensor
// Cos(ctx Context) Tensor
// Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
// QuickGELU(ctx Context, up ...Tensor) Tensor
// SILU(ctx Context, up ...Tensor) Tensor
// RELU(ctx Context, up ...Tensor) Tensor
// Sigmoid(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
AsStrided(ctx Context, shape, strides []int, offset int) Tensor
Transpose(ctx Context, shape ...int) Tensor
Contiguous(ctx Context, allowColMajor bool) Tensor
// Pad(ctx Context, shape ...int) Tensor
// Stack(ctx Context, dim int, s ...Tensor) Tensor
// Repeat repeats the tensor n times along dimension dim
// Repeat(ctx Context, dim, n int) Tensor
// Concat(ctx Context, t2 Tensor, dim int) Tensor
// Rows(ctx Context, t2 Tensor) Tensor
// TODO these probably aren't actually needed - false starts on trying to wire up cache
// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
// PutAlongAxis(ctx Context, indices, values Tensor, axis int) Tensor
Scatter(ctx Context, indices []Tensor, updates Tensor, axes []int) Tensor
Copy(ctx Context, t2 Tensor) Tensor
// Duplicate(ctx Context) Tensor
// Slice(ctx Context, dim, low, high, step int) Tensor
// Chunk(ctx Context, dim int, size int) []Tensor
// ChunkSections(ctx Context, dim int, sections ...int) []Tensor
// TopK(ctx Context, k int) Tensor
// Argsort(ctx Context) Tensor
// Mean(ctx Context) Tensor
// Variance(ctx Context) Tensor
// Stddev(ctx Context) Tensor
// Sqr(ctx Context) Tensor
// Sqrt(ctx Context) Tensor
// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
// type number interface {
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
// ~float32 | ~float64 |
// ~complex64 | ~complex128
// }
// func mul[T number](s ...T) T {
// p := T(1)
// for _, v := range s {
// p *= v
// }
// return p
// }
// type DumpOptions func(*dumpOptions)
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Precision = n
// }
// }
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.Threshold = n
// }
// }
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
// return func(opts *dumpOptions) {
// opts.EdgeItems = n
// }
// }
// type dumpOptions struct {
// Precision, Threshold, EdgeItems int
// }
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
// for _, optsFunc := range optsFuncs {
// optsFunc(&opts)
// }
// if mul(t.Shape()...) <= opts.Threshold {
// opts.EdgeItems = math.MaxInt
// }
// switch t.DType() {
// case DTypeFloat32:
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeFloat16: // TODO other types...
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
// f32 = t.Copy(ctx, f32)
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// })
// case DTypeInt32:
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
// return strconv.FormatInt(int64(i), 10)
// })
// default:
// return "<unsupported>"
// }
// }
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
// if t.Bytes() == nil {
// ctx.Compute(t)
// }
// s := make(S, mul(t.Shape()...))
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
// panic(err)
// }
// shape := t.Shape()
// slices.Reverse(shape)
// var sb strings.Builder
// var f func([]int, int)
// f = func(dims []int, stride int) {
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
// sb.WriteString("[")
// defer func() { sb.WriteString("]") }()
// for i := 0; i < dims[0]; i++ {
// if i >= items && i < dims[0]-items {
// sb.WriteString("..., ")
// // skip to next printable element
// skip := dims[0] - 2*items
// if len(dims) > 1 {
// stride += mul(append(dims[1:], skip)...)
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
// }
// i += skip - 1
// } else if len(dims) > 1 {
// f(dims[1:], stride)
// stride += mul(dims[1:]...)
// if i < dims[0]-1 {
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
// }
// } else {
// text := fn(s[stride+i])
// if len(text) > 0 && text[0] != '-' {
// sb.WriteString(" ")
// }
// sb.WriteString(text)
// if i < dims[0]-1 {
// sb.WriteString(", ")
// }
// }
// }
// }
// f(shape, 0)
// return sb.String()
// }
type DType int
const (
DTypeBool DType = iota
DTypeUint8
DTypeUint16
DTypeUint32
DTypeUint64
DTypeInt8
DTypeInt16
DTypeInt32
DTypeInt64
DTypeFloat16
DTypeFloat32
DTypeFloat64
DTypeBfloat16
DTypeComplex64
)
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)

3
x/ml/backend/backend.go Normal file
View File

@@ -0,0 +1,3 @@
package backend
// _ "github.com/ollama/ollama/x/ml/backend/mlx"

View File

@@ -0,0 +1,57 @@
include(FetchContent)
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)
function(set_target_output_directory _target)
if(TARGET ${_target})
set_target_properties(${_target} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
)
endif()
endfunction()
# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(NOT MLX_METAL_VERSION)
message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
else()
# On Linux, disable Metal backend
message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
set(MLX_BUILD_METAL OFF)
endif()
# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()
# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG v0.4.1)
FetchContent_MakeAvailable(mlx-c)
set_target_output_directory(mlx)
set_target_output_directory(mlxc)

1278
x/ml/backend/mlx/mlx.go Normal file
View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,314 @@
//go:build mlx
package mlx
import (
"log/slog"
"os"
"reflect"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
_ "github.com/ollama/ollama/x/model/models/gemma3"
)
func init() {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
}
func TestLoadModel(t *testing.T) {
dir := "/Users/daniel/Models/gemma-3-4b-it/"
b := &Backend{}
err := b.LoadSafeTensors(dir)
if err != nil {
t.Fatalf("load failed: %s", err)
}
}
func TestFromInts(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []int32{1, 2, 3, 4, 5, 6}
a := c.FromInts(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
}
func TestFromFloats(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
data := []float32{1, 2, 3, 4, 5, 6}
a := c.FromFloats(data, 2, 3)
slog.Info("", "array", a)
t.Log(a.ToString())
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
t.Fatalf("incorrect shape: %v", a.Shape())
}
res := a.Floats()
if !reflect.DeepEqual(res, data) {
t.Fatalf("incorrect results: %v", res)
}
}
func TestAdd(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
t3 := t1.Add(c, t2)
c.Compute(t3, exp)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp.Floats()) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestReshapeTranspose(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
c.Compute(t1)
t1f := t1.Floats()
exp := []float32{
0, 4, 8,
1, 5, 9,
2, 6, 10,
3, 7, 11,
12, 16, 20,
13, 17, 21,
14, 18, 22,
15, 19, 23,
}
if !reflect.DeepEqual(t1f, exp) {
t.Fatalf("incorrect results: %v", t1f)
}
}
func prod(vals ...int) int {
r := 1
for _, v := range vals {
r *= v
}
return r
}
func TestMatmul(t *testing.T) {
// TODO create scenarios...
b := &Backend{}
c := b.NewContext()
defer c.Close()
s1 := []int{1, 3, 2, 4}
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
s2 := []int{4, 2}
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
t3 := t1.Matmul(c, t2)
exp := []float32{
28, 34,
76, 98,
124, 162,
172, 226,
220, 290,
268, 354,
}
c.Compute(t3)
t3f := t3.Floats()
if !reflect.DeepEqual(t3f, exp) {
t.Fatalf("incorrect result: %v", t3f)
}
}
func TestRows(t *testing.T) {
b := &Backend{}
c := b.NewContext()
defer c.Close()
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
outputs := c.Zeros(ml.DTypeInt32, 1)
t2 := t1.TakeAxes(c, outputs, 1)
c.Forward(t1, t2).Compute(t1, t2)
t.Log(t1.ToString())
t.Log(t2.ToString())
f := t2.Floats()
t.Logf("Result: %v", f)
}
func TestCaching(t *testing.T) {
// Validate the caching algorithm
b := &Backend{}
c := b.NewContext()
defer c.Close()
batchSize := 3
headDim := 4
numKVHeads := 2
// Make cache twice the size of one test batch
cells := batchSize * 2
cellSize := numKVHeads * headDim
shape := []int{1, numKVHeads, batchSize, headDim}
stop := float32(1)
for _, x := range shape {
stop *= float32(x)
}
// Create the cache
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
// Input tensor
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
// Reshape to copy into the cache
/*
From MLX python/src/indexing.cpp mlx_scatter_args_array
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
auto up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
*/
numRows := 3
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
// Simulate cells 1,3,5 are available
indices := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
t.Logf("Indices shape%v\n"+indices[0].ToString(), []int{numRows})
axis := []int{0} // The 1,3,5 of the indices refer to axis 0 in the cache shape
cache.Scatter(c, indices, up, axis)
c.Forward(cache)
// Cache should contain the data now
t.Log("Cache after put\n" + cache.ToString())
// Retrieve cache content and verify it matches
out := cache.TakeAxes(c, indices[0], 0).Reshape(c, shape...)
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
t1f := t1.Floats()
outf := out.Floats()
if !reflect.DeepEqual(t1f, outf) {
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
}
}
func TestGemma3(t *testing.T) {
// Why is the sky blue
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
limit := 50
// TODO generalize this
dir := "/Users/daniel/Models/gemma-3-4b-it/"
m, err := model.New(dir, ml.BackendParams{})
if err != nil {
t.Fatalf("unable to load model: %s", err)
}
b := m.Backend()
ctx := b.NewContext()
defer ctx.Close()
batch := input.Batch{
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
Positions: make([]int32, len(inputs)),
Sequences: make([]int, len(inputs)),
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
Offset: 0,
}
for i := range len(inputs) {
batch.Positions[i] = int32(i)
}
offset := len(inputs)
cache := m.Config().Cache
if cache != nil {
numSlots := 1
batchSize := 512
numCtx := 4096
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
err := cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
}
opts := api.DefaultOptions()
var grammar *sample.GrammarSampler
sampler := sample.NewSampler(
opts.Temperature,
opts.TopK,
opts.TopP,
opts.MinP,
opts.Seed,
grammar,
)
t.Log("Starting Forward pass loop")
pendingResponses := []string{}
for {
out, err := m.Forward(ctx, batch)
if err != nil {
t.Fatalf("failed forward pass: %s", err)
}
ctx.Forward(out)
outputs := out.Floats()
t.Logf("finished forward pass! length:%d", len(outputs))
// sample a token
logits := outputs
token, err := sampler.Sample(logits)
if err != nil {
t.Fatalf("unable to sample token: %s", err)
}
t.Logf("Sampled token: %v", token)
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
t.Log("hit EOS")
break
}
piece, err := m.(model.TextProcessor).Decode([]int32{token})
if err != nil {
t.Fatalf("unable to decode token: %s", err)
}
pendingResponses = append(pendingResponses, piece)
sequence := strings.Join(pendingResponses, "")
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
t.Logf("hit stop token: %v", stop)
break
}
t.Logf("RESULTS: %s", sequence)
batch = input.Batch{
Inputs: ctx.FromInts([]int32{token}, 1, 1),
Positions: make([]int32, 1),
Sequences: make([]int, 1),
Outputs: ctx.FromInts([]int32{0}, 1),
Offset: offset,
}
offset++
batch.Positions[0] = 0
err = cache.StartForward(ctx, batch, false)
if err != nil {
t.Fatalf("failed cache.StartForward: %s", err)
}
if offset > limit {
break
}
}
}

335
x/ml/backend/mlx/quant.go Normal file
View File

@@ -0,0 +1,335 @@
//go:build mlx
package mlx
/*
#include <stdio.h>
#include <string.h>
#include "mlx/c/array.h"
#include "mlx/c/ops.h"
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
void unpack_32_4(uint8_t* data, int8_t* dst) {
memset(dst, 0, 16);
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
if (j % 2 != 0) {
x <<= 4;
}
dst[j / 2] += x;
}
// Last 16 weights are in the higher bits
for (int j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] >> 4);
if (j % 2 != 0) {
x <<= 4;
}
dst[8 + j / 2] += x;
}
}
// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = -8 * scales[i];
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
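// Worked example (Q4_0): a row of 64 weights spans 2 blocks = 36 bytes. Each
// 4-bit quant q in [0,15] dequantizes to w = scale*q + bias = scale*(q - 8),
// which is exactly the (weights, scales, biases) triple later handed to
// mlx_dequantize in gguf_load_quantized.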
// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
scales[i] = *((float16_t*)data);
biases[i] = *((float16_t*)(data) + 1);
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
uint8_t* data,
mlx_array* weights_arr,
mlx_array* scales_arr,
mlx_array* biases_arr) {
const uint64_t weights_per_block = 32;
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
float16_t* scales = mlx_array_data_float16(*scales_arr);
float16_t* biases = mlx_array_data_float16(*biases_arr);
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
uint8_t* block_data = data + i * bytes_per_block;
scales[i] = *((float16_t*)block_data);
biases[i] = -128 * scales[i];
for (int64_t j = 0; j < weights_per_block; ++j) {
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
// Original data is in int8_t, so we add a bias of -128 and invert the
// first bit.
x ^= 1 << 7;
weights[i * weights_per_block + j] = x;
}
}
}
// Derived from ggml-quants.c
#define QK_K 256
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
uint16_t d; // super-block scale
} block_q6_K;
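// Size check: one block_q6_K packs QK_K=256 weights into 128+64+16+2 = 210 bytes,
// i.e. 210*8/256 = 6.5625 bits per weight, matching the note above.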
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
const int64_t nb = k / QK_K;
block_q6_K *x = (block_q6_K *)vx;
float16_t* y = (float16_t *)vy;
for (int i = 0; i < nb; i++) {
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
const uint8_t * restrict ql = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l + 0] = d * sc[is + 0] * q1;
y[l + 32] = d * sc[is + 2] * q2;
y[l + 64] = d * sc[is + 4] * q3;
y[l + 96] = d * sc[is + 6] * q4;
}
y += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
union {
struct {
uint16_t d; // super-block scale for quantized scales
uint16_t dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR_S;
uint16_t dm;
} GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
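// Size check: one block_q4_K packs QK_K=256 weights into 4+12+128 = 144 bytes,
// i.e. 144*8/256 = 4.5 bits per weight.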
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
} else {
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
block_q4_K *x = (block_q4_K *)vx;
float16_t* y = (float16_t *)vy;
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const uint8_t * q = x[i].qs;
float16_t d = 0.0;
memcpy(&d, &x[i].d, sizeof(d));
float16_t min = 0.0;
memcpy(&min, &x[i].dmin, sizeof(d));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float16_t d1 = d * sc; const float16_t m1 = min * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float16_t d2 = d * sc; const float16_t m2 = min * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
}
}
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/x448/float16"
)
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
shape := append([]C.int{}, final_shape...)
var weights_per_byte C.int
if dtype == 2 || dtype == 3 {
weights_per_byte = 2
} else if dtype == 8 {
weights_per_byte = 1
} else {
return r, fmt.Errorf("unsupported tensor type %d", dtype)
}
weights_per_block := C.int(32)
if shape[len(shape)-1]%weights_per_block != 0 {
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
}
weights_shape := append([]C.int{}, shape...)
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
for i := range weights_shape {
w_nbytes *= weights_shape[i]
}
w_data := make([]byte, w_nbytes)
cbytes := C.CBytes(w_data)
defer C.free(cbytes)
weights := C.mlx_array_new_data(
cbytes,
&weights_shape[0],
C.int(len(weights_shape)),
C.MLX_UINT32,
)
// For scales and bias
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
for i := range shape {
sb_nbytes *= shape[i]
}
s_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(s_data)
defer C.free(cbytes)
scales := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
b_data := make([]byte, sb_nbytes)
cbytes = C.CBytes(b_data)
defer C.free(cbytes)
biases := C.mlx_array_new_data(
cbytes,
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
var bits C.int
switch dtype {
case 2:
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 3:
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 4
case 8:
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
bits = 8
}
groupSize := C.mlx_optional_int{value: 32, has_value: true}
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
C.mlx_dequantize(
&r,
weights,
scales,
biases,
groupSize,
bitsOpt,
nil, // TODO mode
dtypeOpt,
stream,
)
C.mlx_array_free(weights)
C.mlx_array_free(scales)
C.mlx_array_free(biases)
return r, nil
}
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
size := 1
for _, d := range shape {
size *= int(d)
}
fdata := make([]float16.Float16, size)
switch dtype {
case 14:
C.dequant_row_q6_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
case 12:
C.dequant_row_q4_K(
data,
unsafe.Pointer(&fdata[0]),
C.int(size),
)
default:
return r, fmt.Errorf("unsupported K quant")
}
r = C.mlx_array_new_data(
unsafe.Pointer(&fdata[0]),
&shape[0],
C.int(len(shape)),
C.MLX_FLOAT16,
)
return r, nil
}

643
x/ml/device.go Normal file
View File

@@ -0,0 +1,643 @@
package ml
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"hash/maphash"
"io"
"log/slog"
"math"
"net/http"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/logutil"
)
// GPULayers is a set of layers to be allocated on a single GPU
type GPULayers struct {
DeviceID
// Layers is a set of layer indices to load
Layers []int
}
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
func (g GPULayers) FirstLayer() int {
if len(g.Layers) == 0 {
return math.MaxInt
}
first := g.Layers[0]
for i := 1; i < len(g.Layers); i++ {
if g.Layers[i] < first {
first = g.Layers[i]
}
}
return first
}
func (g GPULayers) String() string {
if len(g.Layers) == 0 {
return ""
}
slices.Sort(g.Layers)
contiguous := true
base := g.Layers[0]
for i := range g.Layers {
if g.Layers[i] != base+i {
contiguous = false
break
}
}
if contiguous {
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
} else {
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
}
}
// GPULayersList is a set of layer allocations across multiple GPUs
type GPULayersList []GPULayers
func (l GPULayersList) Len() int { return len(l) }
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
// Sort by the ordering of the layers offloaded
func (l GPULayersList) Less(i, j int) bool {
li := l[i].FirstLayer()
lj := l[j].FirstLayer()
return li < lj
}
func (l GPULayersList) String() string {
if l.Sum() > 0 {
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
} else {
return fmt.Sprintf("%v", []GPULayers(l))
}
}
// Sum is the total number of layers assigned across all GPUs
func (l GPULayersList) Sum() int {
var sum int
for _, g := range l {
sum += len(g.Layers)
}
return sum
}
var h maphash.Hash
// Hash is an identifier of this layer assignment
func (l GPULayersList) Hash() uint64 {
h.Reset()
for _, g := range l {
if len(g.Layers) > 0 {
h.WriteString(g.ID + g.Library)
for _, l := range g.Layers {
binary.Write(&h, binary.NativeEndian, int64(l))
}
}
}
return h.Sum64()
}
// ErrNoMem is returned when panicking due to insufficient memory. It includes
// the attempted memory allocation.
type ErrNoMem struct {
BackendMemory
}
func (e ErrNoMem) Error() string {
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
}
// Minimal unique device identification
type DeviceID struct {
// ID is an identifier for the device for matching with system
// management libraries. The ID is only unique for other devices
// using the same Library.
// This ID represents a "post filtered" view of the enumerated devices
// if the ID is numeric
ID string `json:"id"`
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
Library string `json:"backend,omitempty"`
}
// DeviceMemory provides a breakdown of the memory needed
// per device, such as a CPU or GPU.
type DeviceMemory struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string
// Weights is the per-layer memory needed for the model weights.
Weights []uint64
// Cache is the per-layer memory needed for the KV cache.
Cache []uint64
// Graph is the size of the compute graph. It is not per-layer.
Graph uint64
}
func sumMemory(mem []uint64) uint64 {
var sum uint64
for _, m := range mem {
sum += m
}
return sum
}
// Size returns the total size of the memory required by this device
func (m DeviceMemory) Size() uint64 {
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
}
func memoryPresent(mem []uint64) bool {
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
}
func (m DeviceMemory) LogValue() slog.Value {
var attrs []slog.Attr
if memoryPresent(m.Weights) {
attrs = append(attrs, slog.Any("Weights", m.Weights))
}
if memoryPresent(m.Cache) {
attrs = append(attrs, slog.Any("Cache", m.Cache))
}
if m.Graph != 0 {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.ID != "" {
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
}
return slog.GroupValue(attrs...)
}
// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
// InputWeights are always located on the CPU and cannot be moved
InputWeights uint64
// CPU model components are located in system memory. This does not
// include unified memory allocated through the GPU.
CPU DeviceMemory
// GPU model components are located on one or more GPUs.
GPUs []DeviceMemory
}
func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr
if m.InputWeights != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
}
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
for _, g := range m.GPUs {
attrs = append(attrs, slog.Any(g.Name, g))
}
return slog.GroupValue(attrs...)
}
// Log prints a high level summary of the memory
func (m BackendMemory) Log(level slog.Level) {
var total uint64
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := sumMemory(gpu.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := sumMemory(m.CPU.Cache); sum > 0 {
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
for _, gpu := range m.GPUs {
if sum := gpu.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
total += sum
}
}
if sum := m.CPU.Graph; sum > 0 {
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
total += sum
}
if total > 0 {
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
}
}
type DeviceInfo struct {
DeviceID
// Name is the name of the device as labeled by the backend. It
// may not be persistent across instances of the runner.
Name string `json:"name"`
// Description is the longer user-friendly identification of the device
Description string `json:"description"`
// FilterID is populated with the unfiltered device ID if a numeric ID is used
// so the device can be included.
FilterID string `json:"filter_id,omitempty"`
// Integrated is set true for integrated GPUs, false for Discrete GPUs
Integrated bool `json:"integration,omitempty"`
// PCIID is the bus, device and domain ID of the device for deduplication
// when discovered by multiple backends
PCIID string `json:"pci_id,omitempty"`
// TotalMemory is the total amount of memory the device can use for loading models
TotalMemory uint64 `json:"total_memory"`
// FreeMemory is the amount of memory currently available on the device for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// ComputeMajor is the major version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMajor int
// ComputeMinor is the minor version of capabilities of the device
// if unsupported by the backend, -1 will be returned
ComputeMinor int
// Driver Information
DriverMajor int `json:"driver_major,omitempty"`
DriverMinor int `json:"driver_minor,omitempty"`
// Where backends were loaded from
LibraryPath []string
}
type SystemInfo struct {
// ThreadCount is the optimal number of threads to use for inference
ThreadCount int `json:"threads,omitempty"`
// TotalMemory is the total amount of system memory
TotalMemory uint64 `json:"total_memory,omitempty"`
// FreeMemory is the amount of memory currently available on the system for loading models
FreeMemory uint64 `json:"free_memory,omitempty"`
// FreeSwap is the amount of system swap space reported as available
FreeSwap uint64 `json:"free_swap,omitempty"`
}
func (d DeviceInfo) Compute() string {
// AMD gfx is encoded into the major minor in hex form
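// e.g. a ROCm device with ComputeMajor=9, ComputeMinor=0x0a renders as "gfx90a",
// while a CUDA device with compute capability 8.6 falls through to "8.6".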
if strings.EqualFold(d.Library, "ROCm") {
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
}
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
}
func (d DeviceInfo) Driver() string {
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
}
// MinimumMemory reports the amount of memory that should be set aside
// on the device for overhead (e.g. VRAM consumed by context structures independent
// of model allocations)
func (d DeviceInfo) MinimumMemory() uint64 {
if d.Library == "Metal" {
return 512 * format.MebiByte
}
return 457 * format.MebiByte
}
// Sort by Free Space.
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
type ByFreeMemory []DeviceInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
if a[i].Integrated && !a[j].Integrated {
return true
} else if !a[i].Integrated && a[j].Integrated {
return false
}
return a[i].FreeMemory < a[j].FreeMemory
}
// ByPerformance groups devices by similar speed
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
scores := []bool{}
for _, info := range l {
found := false
requested := info.Integrated
for i, score := range scores {
if score == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
scores = append(scores, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
resp := [][]DeviceInfo{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, requested)
resp = append(resp, []DeviceInfo{info})
}
}
return resp
}
func LibraryPaths(l []DeviceInfo) []string {
gpuLibs := []string{LibOllamaPath}
for _, gpu := range l {
for _, dir := range gpu.LibraryPath {
needed := true
for _, existing := range gpuLibs {
if dir == existing {
needed = false
break
}
}
if needed {
gpuLibs = append(gpuLibs, dir)
}
}
}
return gpuLibs
}
type DeviceComparison int
const (
UniqueDevice DeviceComparison = iota
SameBackendDevice // The device is the same, and the library/backend is the same
DuplicateDevice // The same physical device but different library/backend (overlapping device)
)
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if a.PCIID != b.PCIID {
return UniqueDevice
}
// If PCIID is empty, we have to use ID + library for uniqueness
if a.PCIID == "" && a.DeviceID != b.DeviceID {
return UniqueDevice
}
if a.Library == b.Library {
return SameBackendDevice
}
return DuplicateDevice
}
// For a SameBackendDevice, return true if b is better than a
// e.g. newer GPU library version
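// e.g. paths ending in "cuda_v12" vs "cuda_v13": the suffix after "_" is compared
// lexically, so b wins when its version suffix sorts higher.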
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := a.LibraryPath[len(a.LibraryPath)-1]
bLib := b.LibraryPath[len(b.LibraryPath)-1]
if aLib == bLib {
return false
}
aLibSplit := strings.SplitN(aLib, "_", 2)
bLibSplit := strings.SplitN(bLib, "_", 2)
if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
return false
}
if aLibSplit[0] != bLibSplit[0] {
slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
return false
}
if aLibSplit[1] == bLibSplit[1] {
return false
}
cmp := []string{aLibSplit[1], bLibSplit[1]}
sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
return cmp[0] == bLibSplit[1]
}
// FlashAttentionSupported returns true only if every listed GPU supports flash attention
func FlashAttentionSupported(l []DeviceInfo) bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
if !supportsFA {
return false
}
}
return true
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
if len(l) == 0 {
return nil
}
env := map[string]string{}
for _, d := range l {
d.updateVisibleDevicesEnv(env, mustFilter)
}
return env
}
// NeedsInitValidation returns true if the device in question has the potential
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
// ROCm: rocblas will crash on unsupported devices.
// CUDA: verify CC is supported by the version of the library
return d.Library == "ROCm" || d.Library == "CUDA"
}
// Set the init validation environment variable
func (d DeviceInfo) AddInitValidation(env map[string]string) {
env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
}
// PreferredLibrary returns true if this library is preferred over the other input
// library
// Used to filter out Vulkan in favor of CUDA or ROCm
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
// TODO in the future if we find Vulkan is better than ROCm on some devices
// that implementation can live here.
if d.Library == "CUDA" || d.Library == "ROCm" {
return true
}
return false
}
func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
var envVar string
switch d.Library {
case "ROCm":
// ROCm must be filtered as it can crash the runner on unsupported devices
envVar = "ROCR_VISIBLE_DEVICES"
if runtime.GOOS != "linux" {
envVar = "HIP_VISIBLE_DEVICES"
}
case "CUDA":
if !mustFilter {
// By default we try to avoid filtering CUDA devices because ROCm also
// looks at the CUDA env var, and gets confused in mixed vendor environments.
return
}
envVar = "CUDA_VISIBLE_DEVICES"
default:
// Vulkan is not filtered via env var, but via scheduling decisions
return
}
v, existing := env[envVar]
if existing {
v = v + ","
}
if d.FilterID != "" {
v = v + d.FilterID
} else {
v = v + d.ID
}
env[envVar] = v
}
type BaseRunner interface {
// GetPort returns the localhost port number the runner is running on
GetPort() int
// HasExited indicates if the runner is no longer running. This can be used during
// bootstrap to detect if a given filtered device is incompatible and triggered an assert
HasExited() bool
}
type RunnerDiscovery interface {
BaseRunner
// GetDeviceInfos will perform a query of the underlying device libraries
// for device identification and free VRAM information
// During bootstrap scenarios, this routine may take seconds to complete
GetDeviceInfos(ctx context.Context) []DeviceInfo
}
type FilteredRunnerDiscovery interface {
RunnerDiscovery
// GetActiveDeviceIDs returns the filtered set of devices actively in
// use by this runner for running models. If the runner is a bootstrap runner, no devices
// will be active yet so no device IDs are returned.
// This routine will not query the underlying device and will return immediately
GetActiveDeviceIDs() []DeviceID
}
func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
var moreDevices []DeviceInfo
port := runner.GetPort()
tick := time.Tick(10 * time.Millisecond)
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to finish discovery before timeout")
case <-tick:
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
// slog.Warn("failed to send request", "error", err)
if runner.HasExited() {
return nil, fmt.Errorf("runner crashed")
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// old runner, fall back to bootstrapping model
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
}
body, err := io.ReadAll(resp.Body)
if err != nil {
slog.Warn("failed to read response", "error", err)
continue
}
if resp.StatusCode != 200 {
logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
return nil, fmt.Errorf("runner error: %s", string(body))
}
if err := json.Unmarshal(body, &moreDevices); err != nil {
slog.Warn("unmarshal encode response", "error", err)
continue
}
return moreDevices, nil
}
}
}

103
x/ml/nn/attention.go Normal file
View File

@@ -0,0 +1,103 @@
package nn
import (
"fmt"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
)
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
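//
// Illustrative call (sketch; q, k, v, headDim and cache are assumed to come from
// the calling layer's projections):
//
//	attn := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)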
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
// panic("after cache get") //
// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
// var mask ml.Tensor
if cache != nil {
key, value, _ = cache.Get(ctx)
}
// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
// panic("after cache get") //
// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if cache != nil {
// TODO what to do with vmla?
// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
} else {
panic("else case not supported")
// TODO transpose shapes are wrong
// key = key.Transpose(ctx, 0, 2, 1, 3)
// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
// kq := query.Matmul(ctx, key)
// kq = kq.Scale(ctx, scale)
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
// kq = kq.Softmax(ctx)
// kqv := kq.Matmul(ctx, value)
// if vmla != nil {
// kqv = kqv.Matmul(ctx, vmla)
// }
// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
}
}

30
x/ml/nn/convolution.go Normal file
View File

@@ -0,0 +1,30 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Conv2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
if m.Bias != nil {
// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
}
return t
}
type Conv3D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}

11
x/ml/nn/embedding.go Normal file
View File

@@ -0,0 +1,11 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Embedding struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
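// Embedding lookup is a row gather: Weight is assumed to be [vocabSize, hiddenDim],
// and TakeAxes selects one row per input token id along axis 0.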
return m.Weight.TakeAxes(ctx, hiddenState, 0)
}

32
x/ml/nn/linear.go Normal file
View File

@@ -0,0 +1,32 @@
package nn
import "github.com/ollama/ollama/x/ml"
type Linear struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
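// Weight is assumed to be stored [outFeatures, inFeatures] (the usual safetensors
// convention), so the input is multiplied by its transpose to yield [..., outFeatures].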
t = t.Matmul(ctx, m.Weight.Transpose(ctx))
if m.Bias != nil {
t = t.Add(ctx, m.Bias)
}
return t
}
type LinearBatch struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
panic("not yet ported")
// t = m.Weight.MulmatID(ctx, t, indices)
// if m.Bias != nil {
// t = t.AddID(ctx, m.Bias, indices)
// }
// return t
}

29
x/ml/nn/normalization.go Normal file
View File

@@ -0,0 +1,29 @@
package nn
import (
"github.com/ollama/ollama/x/ml"
)
type LayerNorm struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}
type RMSNorm struct {
Weight ml.Tensor `gguf:"weight"`
}
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
// slog.Info("RMSNorm", "eps", eps)
// fmt.Fprintln(os.Stderr, t.ToString())
// fmt.Fprintln(os.Stderr, m.Weight.ToString())
// TODO this is probably model specific, not generalized...
w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
return t.RMSNorm(ctx, w, eps)
}

View File

@@ -0,0 +1,41 @@
package pooling
import (
"github.com/ollama/ollama/x/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
)
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
// case TypeMean:
// hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
// return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
// case TypeCLS:
// return hiddenStates.Slice(ctx, 1, 0, 1, 1)
// case TypeLast:
// return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
default:
panic("unknown pooling type")
}
}

72
x/ml/nn/rope/rope.go Normal file
View File

@@ -0,0 +1,72 @@
package rope
import "github.com/ollama/ollama/x/ml"
// Options contains optional parameters for the RoPE function
type Options struct {
Type int
Factors ml.Tensor
// YaRN options
YaRN struct {
OriginalContextLength int
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// MRoPE options
MRoPE struct {
Sections []int
}
}
// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
return func(opts *Options) {
opts.Type = 2
}
}
// WithFactors sets custom RoPE factors
func WithFactors(factors ml.Tensor) func(*Options) {
return func(opts *Options) {
if factors != nil {
opts.Factors = factors
}
}
}
// WithOriginalContextLength sets the original training context length used for YaRN scaling
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.YaRN.OriginalContextLength = n
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.AttentionFactor = attentionFactor
}
}
func WithMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3
opts.MRoPE.Sections = sections
}
}
func WithInterleaveMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1<<3 | 1<<5
opts.MRoPE.Sections = sections
}
}
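A minimal sketch of how these functional options compose, assuming the package is importable as `github.com/ollama/ollama/x/ml/nn/rope` (the path added in this diff); the section sizes are made-up values, not taken from any model.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/ml/nn/rope"
)

func main() {
	var opts rope.Options
	for _, fn := range []func(*rope.Options){
		rope.WithTypeNeoX(),               // sets Type = 2
		rope.WithMRoPE([]int{16, 24, 24}), // ORs in 1<<3 and records the sections
	} {
		fn(&opts)
	}
	fmt.Println(opts.Type, opts.MRoPE.Sections) // 10 [16 24 24]
}
```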

56
x/ml/path.go Normal file
View File

@@ -0,0 +1,56 @@
package ml
import (
"os"
"path/filepath"
"runtime"
)
// LibOllamaPath is a path to look up dynamic libraries.
// In development it's usually 'build/lib/ollama'.
// In distribution builds it's 'lib/ollama' on Windows,
// '../lib/ollama' on Linux, and the executable's directory on macOS.
// Note: in distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {
return ""
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var libPath string
switch runtime.GOOS {
case "windows":
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
case "linux":
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
case "darwin":
libPath = filepath.Dir(exe)
}
cwd, err := os.Getwd()
if err != nil {
return ""
}
paths := []string{
libPath,
// build paths for development
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
filepath.Join(cwd, "build", "lib", "ollama"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p
}
}
return filepath.Dir(exe)
}()

282
x/model/bytepairencoding.go Normal file
View File

@@ -0,0 +1,282 @@
package model
import (
"cmp"
"fmt"
"iter"
"log/slog"
"slices"
"strings"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
return bpe.vocab
}
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and its corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.vocab.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
}
}
}
}
}
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
type lazyIdsString struct {
ids []int32
}
func (l lazyIdsString) LogValue() slog.Value {
return slog.AnyValue(fmt.Sprint(l.ids))
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil
}
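A toy walk-through of the heap-based merge loop in `Encode`, assuming the package is importable as `github.com/ollama/ollama/x/model`; the vocabulary and merge ranks below are invented for illustration.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model"
)

func main() {
	bpe := model.NewBytePairEncoding(&model.Vocabulary{
		Values: []string{"h", "e", "l", "o", "he", "ll", "hell"},
		Types:  []int32{1, 1, 1, 1, 1, 1, 1}, // all TOKEN_TYPE_NORMAL
		Merges: []string{"h e", "l l", "he ll"},
	})

	// "hello" is not in the vocabulary, so it is merged pair by pair in rank
	// order: h+e -> he, l+l -> ll, he+ll -> hell. "hell o" has no merge rank,
	// so the final tokens are "hell" (id 6) and "o" (id 3).
	ids, err := bpe.Encode("hello", false)
	if err != nil {
		panic(err)
	}
	fmt.Println(ids) // [6 3]
}
```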

View File

@@ -0,0 +1,322 @@
package model
import (
"bufio"
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func llama(t testing.TB) BytePairEncoding {
t.Helper()
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
vocab := make(map[string]int32)
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
t.Fatal(err)
}
types := make([]int32, len(vocab))
tokens := make([]string, len(vocab))
for token, id := range vocab {
tokens[id] = token
types[id] = 1
}
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
if _, ok := vocab[token]; !ok {
tokens = append(tokens, token) //nolint:makezero
types = append(types, 3) //nolint:makezero
vocab[token] = int32(len(vocab))
}
}
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
merges := make([]string, 0, 50000)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "#") {
merges = append(merges, scanner.Text())
}
}
return NewBytePairEncoding(
&Vocabulary{
Values: tokens,
Types: types,
Merges: merges,
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
}
func TestLlama(t *testing.T) {
tokenizer := llama(t)
t.Run("simple", func(t *testing.T) {
t.Parallel()
ids, err := tokenizer.Encode("hello world", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
s, err := tokenizer.Decode([]int32{15339, 1917})
if err != nil {
t.Fatal(err)
}
if s != "hello world" {
t.Errorf("got %q, want hello world", s)
}
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
t.Run("simple repeated", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
strings.Repeat("0", 1): {15},
strings.Repeat("0", 2): {410},
strings.Repeat("0", 3): {931},
strings.Repeat("0", 4): {931, 15},
strings.Repeat("0", 5): {931, 410},
strings.Repeat("0", 6): {931, 931},
strings.Repeat("0", 7): {931, 931, 15},
strings.Repeat("0", 8): {931, 931, 410},
strings.Repeat("0", 9): {931, 931, 931},
strings.Repeat("0", 10): {931, 931, 931, 15},
strings.Repeat("0", 11): {931, 931, 931, 410},
strings.Repeat("0", 12): {931, 931, 931, 931},
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
}
}
})
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Error(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
})
t.Run("special", func(t *testing.T) {
t.Parallel()
cases := map[string][]int32{
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
}
for s, want := range cases {
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want, ids); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("split", func(t *testing.T) {
t.Parallel()
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
"Hello World": {"Hello", " ", " World"},
"Hello\nWorld": {"Hello", "\n", "World"},
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
}
})
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
t.Parallel()
for b := 0x00; b <= 0xFF; b++ {
input := string(rune(b))
ids, err := tokenizer.Encode(input, false)
if err != nil {
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
continue
}
if b == 0x00 {
if len(decoded) != 0 {
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
}
continue
}
if decoded != input {
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
}
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
if err != nil {
b.Fatal(err)
}
for i := range 8 {
n := min(int(math.Pow10(i)), len(bts))
bts := bts[:n]
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for b.Loop() {
_, err := tokenizer.Decode(ids)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
slices.Collect(tokenizer.split(string(bts)))
}
})
}
}
func TestSplit(t *testing.T) {
cases := []struct {
name string
patterns,
want []string
}{
{
name: "default",
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
},
{
name: "unicode",
patterns: []string{
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
},
{
name: "individual digits",
patterns: []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
}
}
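The 0x00-0xFF roundtrip test in the file above depends on the byte-to-rune remapping inside `Encode` and `Decode`. Below is a standalone sketch of that mapping (the switch is copied from `Encode`; the helper name is ours, not part of the package), showing how control bytes and whitespace are shifted into printable runes so every byte has a visible vocabulary form.

```go
package main

import "fmt"

func main() {
	remap := func(b byte) rune {
		r := rune(b)
		switch {
		case r == 0x00ad:
			r = 0x0143
		case r <= 0x0020:
			r = r + 0x0100
		case r >= 0x007f && r <= 0x00a0:
			r = r + 0x00a2
		}
		return r
	}
	fmt.Printf("%q -> %q\n", byte(' '), remap(' '))   // space becomes 'Ġ' (U+0120)
	fmt.Printf("%q -> %q\n", byte('\n'), remap('\n')) // newline becomes 'Ċ' (U+010A)
	fmt.Printf("%q -> %q\n", byte('A'), remap('A'))   // printable ASCII is unchanged
}
```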

76
x/model/input/input.go Normal file
View File

@@ -0,0 +1,76 @@
package input
import "github.com/ollama/ollama/x/ml"
// Multimodal is a multimodal embedding or a component of one.
// For example, it could be a row of an image that can be processed
// independently.
type Multimodal struct {
// Tensor is the embedding data. Implementations may choose what to
// store here or it may be nil if not needed. However, any ml.Tensor
// objects must be stored here and not in Data.
Tensor ml.Tensor
// Data is implementation-specific opaque data, such as metadata on how
// to layout Tensor. It may be nil if not needed. It may also store larger
// objects such as complete images if they are to be processed later.
Data any
}
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal represents a non-text element such as an
// image (or part of one if the image can be processed in pieces).
// It may be used either together with Token or on its own.
Multimodal []Multimodal
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal []Multimodal
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Outputs are the set of indices into Inputs for which output data should
// be returned.
Outputs ml.Tensor
// TODO maybe not the optimal way to handle this
// Offset of final tensor in the final batch
Offset int
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
}
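To make the Positions/Sequences relationship concrete, here is a small sketch with made-up values (the Inputs and Outputs tensors are omitted since building them requires an ml.Context): three tokens of sequence 0 followed by two tokens of sequence 1 share one batch.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model/input"
)

func main() {
	b := input.Batch{
		Positions: []int32{0, 1, 2, 0, 1}, // position of each token within its own sequence
		Sequences: []int{0, 0, 0, 1, 1},   // which sequence each token belongs to
	}
	fmt.Println(len(b.Positions) == len(b.Sequences)) // true, as Forward requires
}
```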

333
x/model/model.go Normal file
View File

@@ -0,0 +1,333 @@
package model
import (
"errors"
"fmt"
_ "image/jpeg"
_ "image/png"
"log/slog"
"os"
"reflect"
"strconv"
"strings"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
_ "github.com/ollama/ollama/x/ml/backend"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model/input"
)
var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend
Config() config
}
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is one or more tensors, each with optional model-specific
// opaque metadata. Typically, the tensors might be views into an embedding
// with each view representing a chunk of data that can be processed independently
// in different batches.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]*input.Input) ([]*input.Input, error)
}
// Base implements the common fields and methods for all models
type Base struct {
b ml.Backend
config
}
type config struct {
Cache kvcache.Cache
}
// Backend returns the underlying backend that will run the model
func (m *Base) Backend() ml.Backend {
return m.b
}
func (m *Base) Config() config {
return m.config
}
var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture
func Register(name string, f func(fs.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
}
models[name] = f
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
b, err := ml.NewBackend(modelPath, params)
if err != nil {
return nil, err
}
m, err := modelForArch(b.Config())
if err != nil {
return nil, err
}
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
m, err := modelForArch(meta.KV())
if err != nil {
return nil, err
}
tp, ok := m.(TextProcessor)
if !ok {
return nil, ErrUnsupportedTokenizer
}
return tp, nil
}
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, ErrUnsupportedModel
}
return f(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
if t.Kind() == reflect.Struct {
allNil := true
for i := range t.NumField() {
tt := t.Field(i).Type
vv := v.Field(i)
if !vv.CanSet() {
continue
}
// make a copy
tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, parseTag(tag))
}
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 {
var names []string
if tags[0].name != "" {
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
names = append(names, prefix+n+suffix)
}
}
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
if len(names) == 0 {
// current tag has no name, use child names only
fullNames = append(fullNames, childNames...)
} else if len(childNames) == 0 {
// current tag has names but no children, create branches for each name
for _, name := range names {
fullNames = append(fullNames, []string{name})
}
} else {
// merge each name with each child
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
}
}
}
}
return fullNames
}
names := fn(tagsCopy, "", "")
for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor))
break
}
}
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
setPointer(base, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
for i := range vv.Len() {
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
} else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
}
}
}
if !canNil(tt) || !vv.IsNil() {
allNil = false
}
}
if allNil {
return reflect.Zero(t)
}
}
return v
}
func setPointer(base Base, v reflect.Value, tags []Tag) {
vv := v
if v.Kind() == reflect.Interface {
if v.IsNil() {
return
}
vv = vv.Elem()
}
vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}
if f := populateFields(base, vv, tags...); f.CanAddr() {
v.Set(f.Addr())
}
}
type Tag struct {
name,
// prefix and suffix are applied to child tags
prefix,
suffix string
alternatives []string
}
func parseTag(s string) (tag Tag) {
parts := strings.Split(s, ",")
if len(parts) > 0 {
tag.name = parts[0]
for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
// elevate alternative to primary if no primary given
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
}
}
}
return
}
func canNil(t reflect.Type) bool {
return t.Kind() == reflect.Chan ||
t.Kind() == reflect.Func ||
t.Kind() == reflect.Interface ||
t.Kind() == reflect.Map ||
t.Kind() == reflect.Pointer ||
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
}
ctx.Forward(t)
return t, nil
}
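A sketch of how `parseTag` interprets the `gguf` struct-tag syntax used throughout these models, written as an in-package test since `parseTag` and `Tag` are unexported; the tag string is a made-up example, not one from this diff.

```go
package model

import "testing"

func TestParseTagSketch(t *testing.T) {
	// "q_proj" is the primary name, "wq" an alternative, and "blk." a prefix
	// applied to child tags when building full tensor names.
	tag := parseTag("q_proj,alt:wq,pre:blk.")
	if tag.name != "q_proj" {
		t.Errorf("name = %q, want %q", tag.name, "q_proj")
	}
	if len(tag.alternatives) != 1 || tag.alternatives[0] != "wq" {
		t.Errorf("alternatives = %v, want [wq]", tag.alternatives)
	}
	if tag.prefix != "blk." {
		t.Errorf("prefix = %q, want %q", tag.prefix, "blk.")
	}
}
```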

View File

@@ -0,0 +1,58 @@
//go:build mlx
package gemma3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/ml/nn/pooling"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type embedModel struct {
model.Base
model.SentencePiece
*TextModel
poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
return m, nil
}

View File

@@ -0,0 +1,157 @@
//go:build mlx
package gemma3
import (
"bytes"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/model/input"
)
type Model struct {
model.Base
model.SentencePiece
*VisionModel `gguf:"vision_tower.vision_model"`
*TextModel `gguf:"language_model.model"`
*MultiModalProjector `gguf:"multi_modal_projector"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight
tokensPerImage int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: InputProjection must be transposed since its shape is incompatible with visionOutputs
visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
return visionOutputs
}
func New(c fs.Config) (model.Model, error) {
// slog.Info("XXX Config", "c", c)
m := Model{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
MultiModalProjector: &MultiModalProjector{
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
},
}
// slidingWindowLen := int32(c.Uint("attention.sliding_window"))
// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
// TODO need to implement sliding window...
m.Cache = kvcache.NewMLXCausalCache()
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues := ctx.Input().FromFloats(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal[0].Tensor
result = append(result,
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
&input.Input{Token: 255999}, // "<start_of_image>"
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
&input.Input{Token: 256000}, // <end_of_image>
&input.Input{Token: 108}, // "\n\n"
)
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
}

View File

@@ -0,0 +1,211 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/kvcache"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
"github.com/ollama/ollama/x/model/input"
)
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
largeModelScaling bool
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
Layers []TextLayer `gguf:"layers"`
OutputNorm *nn.RMSNorm `gguf:"norm"`
Output *nn.Linear `gguf:"embed_tokens"`
*TextConfig
}
const (
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 62
)
// const (
// cacheTypeSWA = iota
// cacheTypeCausal
// )
func newTextModel(c fs.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size
numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above
attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python
ropeScale: 1, // 1 - default is 1, implied in python code
// vocabSize: vocabSize, // 262144
// attnValLen: int(c.Uint("attention.value_length", 256)), //256
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
},
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"q_proj"`
QueryNorm *nn.RMSNorm `gguf:"q_norm"`
Key *nn.Linear `gguf:"k_proj"`
KeyNorm *nn.RMSNorm `gguf:"k_norm"`
Value *nn.Linear `gguf:"v_proj"`
Output *nn.Linear `gguf:"o_proj"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
B := hiddenState.Dim(0)
L := hiddenState.Dim(1)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
k := sa.Key.Forward(ctx, hiddenState)
v := sa.Value.Forward(ctx, hiddenState)
q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
traditional := false
q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
// TODO - this is wrong somehow so commenting out
// if opts.largeModelScaling {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
// } else {
// q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
// }
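// note: math.Pow(256, -0.5) == 1/sqrt(256) == 0.0625, i.e. 1/sqrt(attnKeyLen) for the default key length of 256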
scaleFactor := math.Pow(256, -0.5)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// ropeBase := m.TextConfig.ropeLocalBase
// if (layer+1)%gemmaGlobalCacheCount == 0 {
// ropeBase = m.TextConfig.ropeGlobalBase
// }
// q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
panic("not yet implemented")
// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"up_proj"`
Down *nn.Linear `gguf:"down_proj"`
Gate *nn.Linear `gguf:"gate_proj"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *TextSelfAttention `gguf:"self_attn"`
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
MLP *TextMLP `gguf:"mlp"`
PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
residual = residual.TakeAxes(ctx, outputs, 1)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
// var except []int
// for _, image := range batch.Multimodal {
// visionOutputs := image.Multimodal[0].Tensor
// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
// []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
// []int{image.Index * hiddenState.Stride(1)}, 0)))
// for i := range visionOutputs.Dim(1) {
// except = append(except, image.Index+i)
// }
// }
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
if cache != nil {
// cacheType := cacheTypeSWA
// if (i+1)%gemmaGlobalCacheCount == 0 {
// cacheType = cacheTypeCausal
// }
cache.SetLayer(i)
// TODO this needs to come back
// wc := cache.(*kvcache.WrapperCache)
// wc.SetLayerType(cacheType)
// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
// causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
// }
}
var offset int
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
offset = batch.Offset
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}

View File

@@ -0,0 +1,121 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"self_attn.q_proj"`
Key *nn.Linear `gguf:"self_attn.k_proj"`
Value *nn.Linear `gguf:"self_attn.v_proj"`
Output *nn.Linear `gguf:"self_attn.out_proj"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Layers []VisionEncoderLayer `gguf:"encoder.layers"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@@ -0,0 +1,60 @@
//go:build mlx
package gemma3
import (
"image"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals, rVals, gVals, bVals []float32
bounds := img.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}

3
x/model/models/models.go Normal file
View File

@@ -0,0 +1,3 @@
package models
// _ "github.com/ollama/ollama/x/model/models/gemma3"

249
x/model/sentencepiece.go Normal file
View File

@@ -0,0 +1,249 @@
package model
import (
"container/heap"
"fmt"
"log/slog"
"strconv"
"strings"
"github.com/ollama/ollama/logutil"
)
const spmWhitespaceSep = "▁"
type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
ids = append(ids, result...)
}
}
}
if addSpecial {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>",
// convert them back to the raw byte so partial UTF-8
// sequences are buffered correctly by the runner instead
// of being sent back to the API as "<0xEA>"
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {
return "", err
}
} else {
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
}
logutil.Trace("decoded", "ids", ids, "string", sb.String())
return sb.String(), nil
}

View File

@@ -0,0 +1,172 @@
package model
import (
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
t.Fatal(err)
}
var v Vocabulary
for _, piece := range spm.GetPieces() {
v.Values = append(v.Values, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, int32(t))
default:
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v.Types = append(v.Types, tt)
}
}
return NewSentencePiece(&v)
}
func TestSentencePieceEncode(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
tokenizer := loadSentencePieceVocab(t)
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
"你好",
"Hello 你好 world!",
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
"Numbers and symbols: 123456789 +- */",
"Special tokens: <bos> text <eos>",
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Fatal(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q [%#v]", got, want, ids)
}
}
})
t.Run("special tokens", func(t *testing.T) {
type candidate struct {
token string
ids []int32
}
cases := []candidate{
{"<bos>", []int32{2}},
{"<eos>", []int32{1}},
}
for _, want := range cases {
ids, err := tokenizer.Encode(want.token, true)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ids, want.ids) {
t.Errorf("got %#v, want %#v", ids, want.ids)
}
}
})
}
func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
"<0xEA>",
"<0x41>",
"<0xC3>",
"<0xA3>",
},
Types: []int32{
TOKEN_TYPE_NORMAL,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
},
Scores: []float32{0, 0, 0, 0, 0},
}
spm := NewSentencePiece(vocab)
tests := []struct {
name string
ids []int32
expected string
}{
{
name: "single byte token",
ids: []int32{1},
expected: "\xea",
},
{
name: "ASCII byte token",
ids: []int32{2},
expected: "A",
},
{
name: "multiple byte tokens forming UTF-8 character",
ids: []int32{3, 4},
expected: "ã",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := spm.Decode(tt.ids)
if err != nil {
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
}
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}

17
x/model/textprocessor.go Normal file
View File

@@ -0,0 +1,17 @@
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}

112
x/model/vocabulary.go Normal file
View File

@@ -0,0 +1,112 @@
package model
import (
"log/slog"
"slices"
"sync"
)
type Special int32
const (
SpecialBOS Special = iota
SpecialEOS
)
type Vocabulary struct {
Values []string
Types []int32
Scores []float32
Merges []string
BOS, EOS []int32
AddBOS, AddEOS bool
specialOnce sync.Once
special []string
valuesOnce sync.Once
values map[string]int32
mergeOnce sync.Once
merge map[string]int32
}
func (v *Vocabulary) Is(id int32, special Special) bool {
switch special {
case SpecialBOS:
return slices.Contains(v.BOS, id)
case SpecialEOS:
return slices.Contains(v.EOS, id)
default:
return false
}
}
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
if v.AddBOS && len(v.BOS) > 0 {
if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
ids = append([]int32{v.BOS[0]}, ids...)
}
if v.AddEOS && len(v.EOS) > 0 {
if len(ids) > 0 && slices.Contains(v.EOS, ids[len(ids)-1]) {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
ids = append(ids, v.EOS[0])
}
return ids
}
func (v *Vocabulary) Encode(s string) int32 {
v.valuesOnce.Do(func() {
v.values = make(map[string]int32, len(v.Values))
for i, value := range v.Values {
v.values[value] = int32(i)
}
})
if id, ok := v.values[s]; ok {
return id
}
return -1
}
func (v *Vocabulary) Decode(id int32) string {
return v.Values[id]
}
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
v.special = append(v.special, v.Values[i])
}
}
})
return v.special
}
func (v *Vocabulary) Merge(left, right string) int {
v.mergeOnce.Do(func() {
v.merge = make(map[string]int32, len(v.Merges))
for i, merge := range v.Merges {
v.merge[merge] = int32(i)
}
})
if id, ok := v.merge[left+" "+right]; ok {
return int(id)
}
return -1
}
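
To make the lazily built lookup tables concrete, here is a hedged, in-package sketch (for example in a test file with `fmt` imported); the token strings are invented for illustration only:

```go
func ExampleVocabulary() {
	v := &Vocabulary{
		Values: []string{"<s>", "</s>", "hel", "lo", "hello"},
		Merges: []string{"hel lo"},
		BOS:    []int32{0},
		EOS:    []int32{1},
		AddBOS: true,
	}

	fmt.Println(v.Encode("hello"))            // 4: exact match in Values
	fmt.Println(v.Encode("missing"))          // -1: unknown strings map to -1
	fmt.Println(v.Merge("hel", "lo"))         // 0: rank of the "hel lo" merge
	fmt.Println(v.Decode(2))                  // hel
	fmt.Println(v.addSpecials([]int32{2, 3})) // [0 2 3]: BOS prepended; no EOS since AddEOS is false
}
```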

107
x/model/vocabulary_test.go Normal file
View File

@@ -0,0 +1,107 @@
package model
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func TestSpecialVocabulary(t *testing.T) {
vocab := &Vocabulary{
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
}
specialVocab := vocab.SpecialVocabulary()
if len(specialVocab) != 4 {
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
}
}
func TestAddSpecialVocabulary(t *testing.T) {
cases := []struct {
name string
vocab *Vocabulary
input []int32
want []int32
}{
{
name: "add bos",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{2, 3, 4},
want: []int32{0, 2, 3, 4},
},
{
// TODO(mxyng): this is to match previous behaviour
name: "add bos when already present",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{0, 2, 3, 4},
want: []int32{0, 0, 2, 3, 4},
},
{
name: "add eos",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: false,
AddEOS: true,
},
input: []int32{2, 3, 4},
want: []int32{2, 3, 4, 1},
},
{
// TODO(mxyng): this is to match previous behaviour
name: "add eos when already present",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: false,
AddEOS: true,
},
input: []int32{2, 3, 4, 1},
want: []int32{2, 3, 4, 1, 1},
},
{
name: "add both",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: true,
},
input: []int32{2, 3, 4},
want: []int32{0, 2, 3, 4, 1},
},
{
name: "add bos to empty inputs",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{},
want: []int32{0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := tt.vocab.addSpecials(tt.input)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("no match (-want +got):\n%s", diff)
}
})
}
}

171
x/model/wordpiece.go Normal file
View File

@@ -0,0 +1,171 @@
package model
import (
"fmt"
"iter"
"strings"
"unicode"
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
vocab *Vocabulary
lowercase bool
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// This differs from the original WordPiece scheme, which uses "##" to indicate subword continuations.
const ggmlPrefix = "▁"
var wordPieceReplacer = strings.NewReplacer(
" .", ".",
" ?", "?",
" !", "!",
" ,", ",",
" ' ", "'",
" n't", "n't",
" 'm", "'m",
" do not", " don't",
" 's", "'s",
" 've", "'ve",
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
return "", fmt.Errorf("invalid token id: %d", id)
}
var separator string
piece := wpm.vocab.Values[id]
if i > 0 &&
(strings.HasPrefix(piece, ggmlPrefix) ||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
separator = " "
}
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
}
return sb.String(), nil
}
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
switch {
case r >= 0x4E00 && r <= 0x9FFF,
r >= 0x3400 && r <= 0x4DBF,
r >= 0x20000 && r <= 0x2A6DF,
r >= 0x2A700 && r <= 0x2B73F,
r >= 0x2B740 && r <= 0x2B81F,
r >= 0x2B820 && r <= 0x2CEAF,
r >= 0xF900 && r <= 0xFAFF,
r >= 0x2F800 && r <= 0x2FA1F:
runes = append(runes, ' ', r, ' ')
default:
runes = append(runes, r)
}
}
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
// split on punctuation, emitting each punctuation rune as its own token
var start int
for start < len(w) {
end := strings.IndexFunc(w[start:], unicode.IsPunct)
if end < 0 {
end = len(w) - start
} else if end == 0 {
end = 1
}
if !yield(w[start : start+end]) {
return
}
start += end
}
}
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
var start int
var pieces []int32
for start < len(word) {
end := len(word)
var piece int32
for start < end {
subword := word[start:end]
if start == 0 {
subword = ggmlPrefix + subword
}
if wpm.lowercase {
subword = strings.ToLower(subword)
}
piece = wpm.vocab.Encode(subword)
if piece >= 0 {
break
}
end--
}
if piece < 0 {
// no prefix of the word matched; fall back to [UNK] for the whole word
pieces = pieces[:0]
break
}
pieces = append(pieces, piece)
start = end
}
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
ids = append(ids, unk)
}
}
if addSpecial {
ids = wpm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{
vocab: vocab,
lowercase: lowercase,
}
}
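
The `Encode` loop is greedy longest-match-first: for each word it looks for the longest prefix present in the vocabulary (adding the `▁` prefix only to the word-initial piece), commits that piece, and repeats on the remainder; if no prefix matches at all, the whole word becomes `[UNK]`. A hedged in-package sketch with a made-up vocabulary (again assuming `fmt` is available):

```go
func ExampleWordPiece_greedyMatch() {
	wpm := NewWordPiece(&Vocabulary{
		Values: []string{"[UNK]", "▁hello", "s", "▁!"},
	}, true) // lowercase

	// "Hellos!" splits into the words "Hellos" and "!". After lowercasing,
	// "▁hellos" is not in the vocabulary, so the match shrinks to "▁hello"
	// (id 1); the remaining "s" matches id 2 without the ▁ prefix because it
	// is not word-initial; "!" matches "▁!" (id 3).
	ids, _ := wpm.Encode("Hellos!", false)
	fmt.Println(ids) // [1 2 3]
}
```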

53
x/model/wordpiece_test.go Normal file
View File

@@ -0,0 +1,53 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
)
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}