Compare commits
2 Commits (jmorganca/... vs parth/decr)

| Author | SHA1 | Date |
|---|---|---|
|  | 6b2abfb433 |  |
|  | 805ed4644c |  |
@@ -190,7 +190,7 @@ if(MLX_ENGINE)
   install(TARGETS mlx mlxc
     RUNTIME_DEPENDENCIES
       DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-      PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
+      PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
       PRE_EXCLUDE_REGEXES ".*"
     RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
     LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
README.md (41 lines changed)
@@ -48,7 +48,7 @@ ollama run gemma3
 
 ## Model library
 
-Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
+Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
 
 Here are some example models that can be downloaded:
 
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
 | Code Llama | 7B | 3.8GB | `ollama run codellama` |
 | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
 | Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
 
 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,38 +260,6 @@ Finally, in a separate shell, run a model:
 ./ollama run llama3.2
 ```
 
-## Building with MLX (experimental)
-
-First build the MLX libraries:
-
-```shell
-cmake --preset MLX
-cmake --build --preset MLX --parallel
-cmake --install build --component MLX
-```
-
-Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
-
-```shell
-go build -tags mlx -o ollama-mlx .
-```
-
-Finally, start the server:
-
-```
-./ollama serve
-```
-
-### Building MLX with CUDA
-
-When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
-
-```shell
-cmake --preset 'MLX CUDA 13'
-cmake --build --preset 'MLX CUDA 13' --parallel
-cmake --install build --component MLX
-```
-
 ## REST API
 
 Ollama has a REST API for running and managing models.
@@ -322,7 +290,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 
 ### Web & Desktop
 
-- [Onyx](https://github.com/onyx-dot-app/onyx)
 - [Open WebUI](https://github.com/open-webui/open-webui)
 - [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -526,7 +493,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Database
 
 - [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
 - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
 - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
 - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
 - [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -669,7 +636,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
 
 ### Observability
 
 - [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -678,5 +644,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
 
 ### Security
 
 - [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
@@ -14,7 +14,6 @@ extern NSString *SystemWidePath;
 @interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
 @property(strong, nonatomic) NSStatusItem *statusItem;
 @property(assign, nonatomic) BOOL updateAvailable;
-@property(assign, nonatomic) BOOL systemShutdownInProgress;
 @end
 
 @implementation AppDelegate
@@ -41,13 +40,6 @@ bool firstTimeRun,startHidden; // Set in run before initialization
 }
 
 - (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
-  // Register for system shutdown/restart notification so we can allow termination
-  [[[NSWorkspace sharedWorkspace] notificationCenter]
-      addObserver:self
-         selector:@selector(systemWillPowerOff:)
-             name:NSWorkspaceWillPowerOffNotification
-           object:nil];
-
   // if we're in development mode, set the app icon
   NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
   if (![bundlePath hasSuffix:@".app"]) {
@@ -286,18 +278,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
   [NSApp activateIgnoringOtherApps:YES];
 }
 
-- (void)systemWillPowerOff:(NSNotification *)notification {
-  // Set flag so applicationShouldTerminate: knows to allow termination.
-  // The system will call applicationShouldTerminate: after posting this notification.
-  self.systemShutdownInProgress = YES;
-}
-
 - (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
-  // Allow termination if the system is shutting down or restarting
-  if (self.systemShutdownInProgress) {
-    return NSTerminateNow;
-  }
-  // Otherwise just hide the app (for Cmd+Q, close button, etc.)
   [NSApp hide:nil];
   [NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
   return NSTerminateCancel;
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
     Prompt:         ">>> ",
     AltPrompt:      "... ",
     Placeholder:    "Send a message (/? for help)",
-    AltPlaceholder: "Press Enter to send",
+    AltPlaceholder: `Use """ to end multi-line input`,
   })
   if err != nil {
     return err
@@ -21,7 +21,6 @@ ollama pull glm-4.7:cloud
 To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
 
 ```shell
-export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama # required but ignored
 ```
@@ -248,13 +247,12 @@ curl -X POST http://localhost:11434/v1/messages \
 [Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
 
 ```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
 ```
 
 Or set the environment variables in your shell profile:
 
 ```shell
-export ANTHROPIC_AUTH_TOKEN=ollama
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama
 ```
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
 import { Ollama } from "ollama";
 
 const client = new Ollama();
-const results = await client.webSearch("what is ollama?");
+const results = await client.webSearch({ query: "what is ollama?" });
 console.log(JSON.stringify(results, null, 2));
 ```
 
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
 import { Ollama } from "ollama";
 
 const client = new Ollama();
-const fetchResult = await client.webFetch("https://ollama.com");
+const fetchResult = await client.webFetch({ url: "https://ollama.com" });
 console.log(JSON.stringify(fetchResult, null, 2));
 ```
 
@@ -111,9 +111,7 @@
 "/integrations/zed",
 "/integrations/roo-code",
 "/integrations/n8n",
-"/integrations/xcode",
-"/integrations/onyx",
-"/integrations/marimo"
+"/integrations/xcode"
 ]
 },
 {
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
 
 ## How can I specify the context window size?
 
-By default, Ollama uses a context window size of 4096 tokens.
+By default, Ollama uses a context window size of 2048 tokens.
 
 This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
 
(9 image files removed; before sizes: 174 KiB, 80 KiB, 230 KiB, 178 KiB, 186 KiB, 100 KiB, 306 KiB, 300 KiB, 211 KiB)
@@ -25,7 +25,6 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
 1. Set the environment variables:
 
    ```shell
-   export ANTHROPIC_AUTH_TOKEN=ollama
    export ANTHROPIC_BASE_URL=http://localhost:11434
    export ANTHROPIC_API_KEY=ollama
    ```
@@ -39,7 +38,7 @@ claude --model qwen3-coder
 Or run with environment variables inline:
 
 ```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
 ```
 
 ## Connecting to ollama.com
@@ -1,73 +0,0 @@
----
-title: marimo
----
-
-## Install
-
-Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
-can also use `uv` to create a sandboxed environment for marimo by running:
-
-```
-uvx marimo edit --sandbox notebook.py
-```
-
-## Usage with Ollama
-
-1. In marimo, go to the user settings and go to the AI tab. From here
-   you can find and configure Ollama as an AI provider. For local use you
-   would typically point the base url to `http://localhost:11434/v1`.
-
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/marimo-settings.png"
-       alt="Ollama settings in marimo"
-       width="50%"
-     />
-   </div>
-
-2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
-
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/marimo-models.png"
-       alt="Selecting an Ollama model"
-       width="50%"
-     />
-   </div>
-
-3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
-
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/marimo-add-model.png"
-       alt="Adding a new Ollama model"
-       width="50%"
-     />
-   </div>
-
-4. Once configured, you can now use Ollama for AI chats in marimo.
-
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/marimo-chat.png"
-       alt="Configure code completion"
-       width="50%"
-     />
-   </div>
-
-4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
-
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/marimo-code-completion.png"
-       alt="Configure code completion"
-       width="50%"
-     />
-   </div>
-
-
-## Connecting to ollama.com
-
-1. Sign in to ollama cloud via `ollama signin`
-2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
-3. You can now refer to this model in marimo!
@@ -1,63 +0,0 @@
----
-title: Onyx
----
-
-## Overview
-[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
-- Creating custom Agents
-- Web search
-- Deep Research
-- RAG over uploaded documents and connected apps
-- Connectors to applications like Google Drive, Email, Slack, etc.
-- MCP and OpenAPI Actions support
-- Image generation
-- User/Groups management, RBAC, SSO, etc.
-
-Onyx can be deployed for single users or large organizations.
-
-## Install Onyx
-
-Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
-
-<Info>
-  Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
-</Info>
-
-## Usage with Ollama
-
-1. Login to your Onyx deployment (create an account first).
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/onyx-login.png"
-       alt="Onyx Login Page"
-       width="75%"
-     />
-   </div>
-2. In the set-up process select `Ollama` as the LLM provider.
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/onyx-ollama-llm.png"
-       alt="Onyx Set Up Form"
-       width="75%"
-     />
-   </div>
-3. Provide your **Ollama API URL** and select your models.
-   <Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
-     <img
-       src="/images/onyx-ollama-form.png"
-       alt="Selecting Ollama Models"
-       width="75%"
-     />
-   </div>
-
-You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
-
-## Send your first query
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/onyx-query.png"
-    alt="Onyx Query Example"
-    width="75%"
-  />
-</div>
@@ -1,5 +1,5 @@
 ---
-title: Linux
+title: "Linux"
 ---
 
 ## Install
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
 ## Manual install
 
 <Note>
-  If you are upgrading from a prior version, you should remove the old libraries
-  with `sudo rm -rf /usr/lib/ollama` first.
+  If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
 </Note>
 
 Download and extract the package:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+    | sudo tar zx -C /usr
 ```
 
 Start Ollama:
@@ -41,8 +40,8 @@ ollama -v
 If you have an AMD GPU, also download and extract the additional ROCm package:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
+    | sudo tar zx -C /usr
 ```
 
 ### ARM64 install
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
 Download and extract the ARM64-specific package:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
+    | sudo tar zx -C /usr
 ```
 
 ### Adding Ollama as a startup service (recommended)
@@ -113,11 +112,7 @@ sudo systemctl status ollama
 ```
 
 <Note>
-  While AMD has contributed the `amdgpu` driver upstream to the official linux
-  kernel source, the version is older and may not support all ROCm features. We
-  recommend you install the latest driver from
-  https://www.amd.com/en/support/linux-drivers for best support of your Radeon
-  GPU.
+  While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
 </Note>
 
 ## Customizing
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
 Or by re-downloading Ollama:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+    | sudo tar zx -C /usr
 ```
 
 ## Installing specific versions
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
       t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
     }
 
-    if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
+    if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
       t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
     }
   case <-ctx.Done():
@@ -1464,11 +1464,6 @@ type CompletionRequest struct {
 
   // TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
   TopLogprobs int
-
-  // Image generation fields
-  Width  int32 `json:"width,omitempty"`
-  Height int32 `json:"height,omitempty"`
-  Seed   int64 `json:"seed,omitempty"`
 }
 
 // DoneReason represents the reason why a completion response is done
@@ -1517,11 +1512,6 @@ type CompletionResponse struct {
 
   // Logprobs contains log probability information if requested
   Logprobs []Logprob `json:"logprobs,omitempty"`
-
-  // Image generation fields
-  Image []byte `json:"image,omitempty"` // Generated image
-  Step  int    `json:"step,omitempty"`  // Current generation step
-  Total int    `json:"total,omitempty"` // Total generation steps
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -8,7 +8,6 @@ import (
   "math/rand"
   "net/http"
   "strings"
-  "time"
 
   "github.com/gin-gonic/gin"
 
@@ -442,7 +441,6 @@ type ResponsesWriter struct {
   stream     bool
   responseID string
   itemID     string
-  request    openai.ResponsesRequest
 }
 
 func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -480,9 +478,7 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
 
   // Non-streaming response
   w.ResponseWriter.Header().Set("Content-Type", "application/json")
-  response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
-  completedAt := time.Now().Unix()
-  response.CompletedAt = &completedAt
+  response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
   return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
 }
 
@@ -527,12 +523,11 @@ func ResponsesMiddleware() gin.HandlerFunc {
 
   w := &ResponsesWriter{
     BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-    converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
+    converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
     model:      req.Model,
     stream:     streamRequested,
     responseID: responseID,
     itemID:     itemID,
-    request:    req,
   }
 
   // Set headers based on streaming mode
@@ -630,10 +630,6 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
 
 // decodeImageURL decodes a base64 data URI into raw image bytes.
 func decodeImageURL(url string) (api.ImageData, error) {
-  if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
-    return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
-  }
-
   types := []string{"jpeg", "jpg", "png", "webp"}
 
   // Support blank mime type to match /api/chat's behavior of taking just unadorned base64
@@ -4,7 +4,6 @@ import (
   "encoding/json"
   "fmt"
   "math/rand"
-  "time"
 
   "github.com/ollama/ollama/api"
 )
@@ -266,9 +265,9 @@ type ResponsesText struct {
 type ResponsesTool struct {
   Type string `json:"type"` // "function"
   Name string `json:"name"`
-  Description *string        `json:"description"` // nullable but required
-  Strict      *bool          `json:"strict"`      // nullable but required
-  Parameters  map[string]any `json:"parameters"`  // nullable but required
+  Description string         `json:"description,omitempty"`
+  Strict      bool           `json:"strict,omitempty"`
+  Parameters  map[string]any `json:"parameters,omitempty"`
 }
 
 type ResponsesRequest struct {
@@ -476,16 +475,11 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
     }
   }
 
-  var description string
-  if t.Description != nil {
-    description = *t.Description
-  }
-
   return api.Tool{
     Type: t.Type,
     Function: api.ToolFunction{
       Name:        t.Name,
-      Description: description,
+      Description: t.Description,
       Parameters:  params,
     },
   }, nil
@@ -522,60 +516,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
 
 // Response types for the Responses API
 
-// ResponsesTextField represents the text output configuration in the response.
-type ResponsesTextField struct {
-  Format ResponsesTextFormat `json:"format"`
-}
-
-// ResponsesReasoningOutput represents reasoning configuration in the response.
-type ResponsesReasoningOutput struct {
-  Effort  *string `json:"effort,omitempty"`
-  Summary *string `json:"summary,omitempty"`
-}
-
-// ResponsesError represents an error in the response.
-type ResponsesError struct {
-  Code    string `json:"code"`
-  Message string `json:"message"`
-}
-
-// ResponsesIncompleteDetails represents details about why a response was incomplete.
-type ResponsesIncompleteDetails struct {
-  Reason string `json:"reason"`
-}
-
 type ResponsesResponse struct {
   ID        string `json:"id"`
   Object    string `json:"object"`
   CreatedAt int64  `json:"created_at"`
-  CompletedAt        *int64                      `json:"completed_at"`
-  Status             string                      `json:"status"`
-  IncompleteDetails  *ResponsesIncompleteDetails `json:"incomplete_details"`
-  Model              string                      `json:"model"`
-  PreviousResponseID *string                     `json:"previous_response_id"`
-  Instructions       *string                     `json:"instructions"`
-  Output             []ResponsesOutputItem       `json:"output"`
-  Error              *ResponsesError             `json:"error"`
-  Tools              []ResponsesTool             `json:"tools"`
-  ToolChoice         any                         `json:"tool_choice"`
-  Truncation         string                      `json:"truncation"`
-  ParallelToolCalls  bool                        `json:"parallel_tool_calls"`
-  Text               ResponsesTextField          `json:"text"`
-  TopP               float64                     `json:"top_p"`
-  PresencePenalty    float64                     `json:"presence_penalty"`
-  FrequencyPenalty   float64                     `json:"frequency_penalty"`
-  TopLogprobs        int                         `json:"top_logprobs"`
-  Temperature        float64                     `json:"temperature"`
-  Reasoning          *ResponsesReasoningOutput   `json:"reasoning"`
-  Usage              *ResponsesUsage             `json:"usage"`
-  MaxOutputTokens    *int                        `json:"max_output_tokens"`
-  MaxToolCalls       *int                        `json:"max_tool_calls"`
-  Store              bool                        `json:"store"`
-  Background         bool                        `json:"background"`
-  ServiceTier        string                      `json:"service_tier"`
-  Metadata           map[string]any              `json:"metadata"`
-  SafetyIdentifier   *string                     `json:"safety_identifier"`
-  PromptCacheKey     *string                     `json:"prompt_cache_key"`
+  Status    string                `json:"status"`
+  Model     string                `json:"model"`
+  Output    []ResponsesOutputItem `json:"output"`
+  Usage     *ResponsesUsage       `json:"usage,omitempty"`
+  // TODO(drifkin): add `temperature` and `top_p` to the response, but this
+  // requires additional plumbing to find the effective values since the
+  // defaults can come from the model or the request
 }
 
 type ResponsesOutputItem struct {
@@ -599,39 +550,18 @@ type ResponsesReasoningSummary struct {
 }
 
 type ResponsesOutputContent struct {
   Type string `json:"type"` // "output_text"
   Text string `json:"text"`
-  Annotations []any `json:"annotations"`
-  Logprobs    []any `json:"logprobs"`
-}
-
-type ResponsesInputTokensDetails struct {
-  CachedTokens int `json:"cached_tokens"`
-}
-
-type ResponsesOutputTokensDetails struct {
-  ReasoningTokens int `json:"reasoning_tokens"`
 }
 
 type ResponsesUsage struct {
   InputTokens  int `json:"input_tokens"`
   OutputTokens int `json:"output_tokens"`
   TotalTokens  int `json:"total_tokens"`
-  InputTokensDetails  ResponsesInputTokensDetails  `json:"input_tokens_details"`
-  OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
 }
 
-// derefFloat64 returns the value of a float64 pointer, or a default if nil.
-func derefFloat64(p *float64, def float64) float64 {
-  if p != nil {
-    return *p
-  }
-  return def
-}
-
-// ToResponse converts an api.ChatResponse to a Responses API response.
-// The request is used to echo back request parameters in the response.
-func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
+// ToResponse converts an api.ChatResponse to a Responses API response
+func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
   var output []ResponsesOutputItem
 
   // Add reasoning item if thinking is present
@@ -655,7 +585,6 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
     output = append(output, ResponsesOutputItem{
       ID:        fmt.Sprintf("fc_%s_%d", responseID, i),
       Type:      "function_call",
-      Status:    "completed",
       CallID:    tc.ID,
       Name:      tc.Function.Name,
       Arguments: tc.Function.Arguments,
@@ -669,90 +598,25 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
       Role: "assistant",
       Content: []ResponsesOutputContent{
         {
           Type: "output_text",
           Text: chatResponse.Message.Content,
-          Annotations: []any{},
-          Logprobs:    []any{},
         },
       },
     })
   }
 
-  var instructions *string
-  if request.Instructions != "" {
-    instructions = &request.Instructions
-  }
-
-  // Build truncation with default
-  truncation := "disabled"
-  if request.Truncation != nil {
-    truncation = *request.Truncation
-  }
-
-  tools := request.Tools
-  if tools == nil {
-    tools = []ResponsesTool{}
-  }
-
-  text := ResponsesTextField{
-    Format: ResponsesTextFormat{Type: "text"},
-  }
-  if request.Text != nil && request.Text.Format != nil {
-    text.Format = *request.Text.Format
-  }
-
-  // Build reasoning output from request
-  var reasoning *ResponsesReasoningOutput
-  if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
-    reasoning = &ResponsesReasoningOutput{}
-    if request.Reasoning.Effort != "" {
-      reasoning.Effort = &request.Reasoning.Effort
-    }
-    if request.Reasoning.Summary != "" {
-      reasoning.Summary = &request.Reasoning.Summary
-    }
-  }
-
   return ResponsesResponse{
     ID:        responseID,
     Object:    "response",
     CreatedAt: chatResponse.CreatedAt.Unix(),
-    CompletedAt:        nil, // Set by middleware when writing final response
-    Status:             "completed",
-    IncompleteDetails:  nil, // Only populated if response incomplete
-    Model:              model,
-    PreviousResponseID: nil, // Not supported
-    Instructions:       instructions,
-    Output:             output,
-    Error:              nil, // Only populated on failure
-    Tools:              tools,
-    ToolChoice:         "auto", // Default value
-    Truncation:         truncation,
-    ParallelToolCalls:  true, // Default value
-    Text:               text,
-    TopP:               derefFloat64(request.TopP, 1.0),
-    PresencePenalty:    0, // Default value
-    FrequencyPenalty:   0, // Default value
-    TopLogprobs:        0, // Default value
-    Temperature:        derefFloat64(request.Temperature, 1.0),
-    Reasoning:          reasoning,
+    Status:    "completed",
+    Model:     model,
+    Output:    output,
     Usage: &ResponsesUsage{
       InputTokens:  chatResponse.PromptEvalCount,
       OutputTokens: chatResponse.EvalCount,
       TotalTokens:  chatResponse.PromptEvalCount + chatResponse.EvalCount,
-      // TODO(drifkin): wire through the actual values
-      InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
-      // TODO(drifkin): wire through the actual values
-      OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
     },
-    MaxOutputTokens:  request.MaxOutputTokens,
-    MaxToolCalls:     nil, // Not supported
-    Store:            false, // We don't store responses
-    Background:       request.Background,
-    ServiceTier:      "default", // Default value
-    Metadata:         map[string]any{},
-    SafetyIdentifier: nil, // Not supported
-    PromptCacheKey:   nil, // Not supported
   }
 }
 
@@ -772,7 +636,6 @@ type ResponsesStreamConverter struct {
   responseID string
   itemID     string
   model      string
-  request    ResponsesRequest
 
   // State tracking (mutated across Process calls)
   firstWrite bool
@@ -805,12 +668,11 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
 }
 
 // NewResponsesStreamConverter creates a new converter with the given configuration.
-func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
+func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
   return &ResponsesStreamConverter{
     responseID: responseID,
     itemID:     itemID,
     model:      model,
-    request:    request,
     firstWrite: true,
   }
 }
@@ -855,120 +717,25 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
   return events
 }
 
-// buildResponseObject creates a full response object with all required fields for streaming events.
-func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
-  var instructions any = nil
-  if c.request.Instructions != "" {
-    instructions = c.request.Instructions
-  }
-
-  truncation := "disabled"
-  if c.request.Truncation != nil {
-    truncation = *c.request.Truncation
-  }
-
-  var tools []any
-  if c.request.Tools != nil {
-    for _, t := range c.request.Tools {
-      tools = append(tools, map[string]any{
-        "type":        t.Type,
-        "name":        t.Name,
-        "description": t.Description,
-        "strict":      t.Strict,
-        "parameters":  t.Parameters,
-      })
-    }
-  }
-  if tools == nil {
-    tools = []any{}
-  }
-
-  textFormat := map[string]any{"type": "text"}
-  if c.request.Text != nil && c.request.Text.Format != nil {
-    textFormat = map[string]any{
-      "type": c.request.Text.Format.Type,
-    }
-    if c.request.Text.Format.Name != "" {
-      textFormat["name"] = c.request.Text.Format.Name
-    }
-    if c.request.Text.Format.Schema != nil {
-      textFormat["schema"] = c.request.Text.Format.Schema
-    }
-    if c.request.Text.Format.Strict != nil {
-      textFormat["strict"] = *c.request.Text.Format.Strict
-    }
-  }
-
-  var reasoning any = nil
-  if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
-    r := map[string]any{}
-    if c.request.Reasoning.Effort != "" {
-      r["effort"] = c.request.Reasoning.Effort
-    } else {
-      r["effort"] = nil
-    }
-    if c.request.Reasoning.Summary != "" {
-      r["summary"] = c.request.Reasoning.Summary
-    } else {
-      r["summary"] = nil
-    }
-    reasoning = r
-  }
-
-  // Build top_p and temperature with defaults
-  topP := 1.0
-  if c.request.TopP != nil {
-    topP = *c.request.TopP
-  }
-  temperature := 1.0
-  if c.request.Temperature != nil {
-    temperature = *c.request.Temperature
-  }
-
-  return map[string]any{
-    "id":                   c.responseID,
-    "object":               "response",
-    "created_at":           time.Now().Unix(),
-    "completed_at":         nil,
-    "status":               status,
-    "incomplete_details":   nil,
-    "model":                c.model,
-    "previous_response_id": nil,
-    "instructions":         instructions,
-    "output":               output,
-    "error":                nil,
-    "tools":                tools,
-    "tool_choice":          "auto",
-    "truncation":           truncation,
-    "parallel_tool_calls":  true,
-    "text":                 map[string]any{"format": textFormat},
-    "top_p":                topP,
-    "presence_penalty":     0,
-    "frequency_penalty":    0,
-    "top_logprobs":         0,
-    "temperature":          temperature,
-    "reasoning":            reasoning,
-    "usage":                usage,
-    "max_output_tokens":    c.request.MaxOutputTokens,
-    "max_tool_calls":       nil,
-    "store":                false,
-    "background":           c.request.Background,
-    "service_tier":         "default",
-    "metadata":             map[string]any{},
-    "safety_identifier":    nil,
-    "prompt_cache_key":     nil,
-  }
-}
-
 func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
   return c.newEvent("response.created", map[string]any{
-    "response": c.buildResponseObject("in_progress", []any{}, nil),
+    "response": map[string]any{
+      "id":     c.responseID,
+      "object": "response",
+      "status": "in_progress",
+      "output": []any{},
+    },
   })
 }
 
 func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
   return c.newEvent("response.in_progress", map[string]any{
-    "response": c.buildResponseObject("in_progress", []any{}, nil),
+    "response": map[string]any{
+      "id":     c.responseID,
+      "object": "response",
+      "status": "in_progress",
+      "output": []any{},
    },
   })
 }
 
@@ -995,10 +762,9 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
 
   // Emit delta
   events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
     "item_id":      c.reasoningItemID,
     "output_index": c.outputIndex,
-    "summary_index": 0,
     "delta":        thinking,
   }))
 
   // TODO(drifkin): consider adding
@@ -1017,10 +783,9 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
 
   events := []ResponsesStreamEvent{
     c.newEvent("response.reasoning_summary_text.done", map[string]any{
       "item_id":      c.reasoningItemID,
       "output_index": c.outputIndex,
-      "summary_index": 0,
       "text":         c.accumulatedThinking,
     }),
     c.newEvent("response.output_item.done", map[string]any{
       "output_index": c.outputIndex,
@@ -1133,10 +898,8 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
       "output_index":  c.outputIndex,
       "content_index": c.contentIndex,
       "part": map[string]any{
         "type": "output_text",
         "text": "",
-        "annotations": []any{},
-        "logprobs":    []any{},
       },
     }))
   }
@@ -1150,7 +913,6 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
     "output_index":  c.outputIndex,
     "content_index": 0,
     "delta":         content,
-    "logprobs":      []any{},
   }))
 
   return events
@@ -1182,10 +944,8 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
       "status": "completed",
       "role":   "assistant",
       "content": []map[string]any{{
         "type": "output_text",
         "text": c.accumulatedText,
-        "annotations": []any{},
-        "logprobs":    []any{},
       }},
     })
   }
@@ -1207,7 +967,6 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
     "output_index":  c.outputIndex,
     "content_index": 0,
     "text":          c.accumulatedText,
-    "logprobs":      []any{},
   }))
 
   // response.content_part.done
@@ -1216,10 +975,8 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
     "output_index":  c.outputIndex,
     "content_index": 0,
     "part": map[string]any{
       "type": "output_text",
       "text": c.accumulatedText,
-      "annotations": []any{},
-      "logprobs":    []any{},
     },
   }))
 
@@ -1232,31 +989,26 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
       "status": "completed",
       "role":   "assistant",
       "content": []map[string]any{{
         "type": "output_text",
         "text": c.accumulatedText,
-        "annotations": []any{},
-        "logprobs":    []any{},
       }},
      },
    }))
   }
 
   // response.completed
-  usage := map[string]any{
-    "input_tokens":  r.PromptEvalCount,
-    "output_tokens": r.EvalCount,
-    "total_tokens":  r.PromptEvalCount + r.EvalCount,
-    "input_tokens_details": map[string]any{
-      "cached_tokens": 0,
-    },
-    "output_tokens_details": map[string]any{
-      "reasoning_tokens": 0,
-    },
-  }
-  response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
-  response["completed_at"] = time.Now().Unix()
   events = append(events, c.newEvent("response.completed", map[string]any{
-    "response": response,
+    "response": map[string]any{
+      "id":     c.responseID,
+      "object": "response",
+      "status": "completed",
+      "output": c.buildFinalOutput(),
+      "usage": map[string]any{
+        "input_tokens":  r.PromptEvalCount,
+        "output_tokens": r.EvalCount,
+        "total_tokens":  r.PromptEvalCount + r.EvalCount,
+      },
+    },
   }))
 
   return events
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
 }
 
 func TestResponsesStreamConverter_TextOnly(t *testing.T) {
-  converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+  converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
   // First chunk with content
   events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
 }
 
 func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
-  converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+  converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
   events := converter.Process(api.ChatResponse{
     Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
 }
 
 func TestResponsesStreamConverter_Reasoning(t *testing.T) {
-  converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+  converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
   // First chunk with thinking
   events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
       Content: "The answer is 42",
     },
     Done: true,
-  }, ResponsesRequest{})
+  })
 
   // Should have 2 output items: reasoning + message
   if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
|
|||||||
|
|
||||||
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||||
// Verify that response.output_item.done includes content field for messages
|
// Verify that response.output_item.done includes content field for messages
|
||||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||||
|
|
||||||
// First chunk
|
// First chunk
|
||||||
converter.Process(api.ChatResponse{
|
converter.Process(api.ChatResponse{
|
||||||
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
|||||||
|
|
||||||
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
|
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
|
||||||
// Verify that response.completed includes the output array
|
// Verify that response.completed includes the output array
|
||||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||||
|
|
||||||
// Process some content
|
// Process some content
|
||||||
converter.Process(api.ChatResponse{
|
converter.Process(api.ChatResponse{
|
||||||
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
|
|||||||
|
|
||||||
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||||
// Verify that response.created includes an empty output array
|
// Verify that response.created includes an empty output array
|
||||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||||
|
|
||||||
events := converter.Process(api.ChatResponse{
|
events := converter.Process(api.ChatResponse{
|
||||||
Message: api.Message{Content: "Hi"},
|
Message: api.Message{Content: "Hi"},
|
||||||
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
|||||||
|
|
||||||
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||||
// Verify that events include incrementing sequence numbers
|
// Verify that events include incrementing sequence numbers
|
||||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||||
|
|
||||||
events := converter.Process(api.ChatResponse{
|
events := converter.Process(api.ChatResponse{
|
||||||
Message: api.Message{Content: "Hello"},
|
Message: api.Message{Content: "Hello"},
|
||||||
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
|||||||
|
|
||||||
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||||
// Verify that function call items include status field
|
// Verify that function call items include status field
|
||||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||||
|
|
||||||
events := converter.Process(api.ChatResponse{
|
events := converter.Process(api.ChatResponse{
|
||||||
Message: api.Message{
|
Message: api.Message{
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"os"
-"strings"
)

type Prompt struct {
@@ -37,11 +36,10 @@ type Terminal struct {
}

type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
-pastedLines []string
}

func New(prompt Prompt) (*Instance, error) {
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
-i.pastedLines = nil
-i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
-if buf.IsEmpty() && len(i.pastedLines) > 0 {
-lastIdx := len(i.pastedLines) - 1
-prevLine := i.pastedLines[lastIdx]
-i.pastedLines = i.pastedLines[:lastIdx]
-fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
-if len(i.pastedLines) == 0 {
-fmt.Print(i.Prompt.Prompt)
-i.Prompt.UseAlt = false
-} else {
-fmt.Print(i.Prompt.AltPrompt)
-}
-for _, r := range prevLine {
-buf.Add(r)
-}
-} else {
-buf.Remove()
-}
+buf.Remove()
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
-case CharCtrlJ:
-i.pastedLines = append(i.pastedLines, buf.String())
-buf.Buf.Clear()
-buf.Pos = 0
-buf.DisplayPos = 0
-buf.LineHasSpace.Clear()
-fmt.Println()
-fmt.Print(i.Prompt.AltPrompt)
-i.Prompt.UseAlt = true
-continue
-case CharEnter:
+case CharEnter, CharCtrlJ:
output := buf.String()
-if len(i.pastedLines) > 0 {
-output = strings.Join(i.pastedLines, "\n") + "\n" + output
-i.pastedLines = nil
-}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
-i.Prompt.UseAlt = false

return output, nil
default:
@@ -179,7 +179,7 @@ _build_macapp() {
fi

rm -f dist/Ollama-darwin.zip
-ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz

# Notarize and Staple
@@ -187,7 +187,7 @@ _build_macapp() {
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
-ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip

rm -f dist/Ollama.dmg
@@ -50,17 +50,12 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil
}

-func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
+func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}

-// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
-if redirectURL.Host != originalHost {
-return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
-}
-
sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
@@ -1,113 +0,0 @@
-package server
-
-import (
-"context"
-"strings"
-"testing"
-"time"
-)
-
-func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
-tests := []struct {
-realm string
-originalHost string
-wantMismatch bool
-}{
-{"https://example.com/token", "example.com", false},
-{"https://example.com/token", "other.com", true},
-{"https://example.com/token", "localhost:8000", true},
-{"https://localhost:5000/token", "localhost:5000", false},
-{"https://localhost:5000/token", "localhost:6000", true},
-}
-
-for _, tt := range tests {
-t.Run(tt.originalHost, func(t *testing.T) {
-ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
-defer cancel()
-
-challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
-_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
-
-isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
-if tt.wantMismatch && !isMismatch {
-t.Errorf("expected domain mismatch error, got: %v", err)
-}
-if !tt.wantMismatch && isMismatch {
-t.Errorf("unexpected domain mismatch error: %v", err)
-}
-})
-}
-}
-
-func TestParseRegistryChallenge(t *testing.T) {
-tests := []struct {
-input string
-wantRealm, wantService, wantScope string
-}{
-{
-`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
-"https://auth.example.com/token", "registry", "repo:foo:pull",
-},
-{
-`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
-"https://r.ollama.ai/v2/token", "ollama", "-",
-},
-{"", "", "", ""},
-}
-
-for _, tt := range tests {
-result := parseRegistryChallenge(tt.input)
-if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
-t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
-tt.input, result.Realm, result.Service, result.Scope,
-tt.wantRealm, tt.wantService, tt.wantScope)
-}
-}
-}
-
-func TestRegistryChallengeURL(t *testing.T) {
-challenge := registryChallenge{
-Realm: "https://auth.example.com/token",
-Service: "registry",
-Scope: "repo:foo:pull repo:bar:push",
-}
-
-u, err := challenge.URL()
-if err != nil {
-t.Fatalf("URL() error: %v", err)
-}
-
-if u.Host != "auth.example.com" {
-t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
-}
-if u.Path != "/token" {
-t.Errorf("path = %q, want %q", u.Path, "/token")
-}
-
-q := u.Query()
-if q.Get("service") != "registry" {
-t.Errorf("service = %q, want %q", q.Get("service"), "registry")
-}
-if scopes := q["scope"]; len(scopes) != 2 {
-t.Errorf("scope count = %d, want 2", len(scopes))
-}
-if q.Get("ts") == "" {
-t.Error("missing ts")
-}
-if q.Get("nonce") == "" {
-t.Error("missing nonce")
-}
-
-// Nonces should differ between calls
-u2, _ := challenge.URL()
-if q.Get("nonce") == u2.Query().Get("nonce") {
-t.Error("nonce should be unique per call")
-}
-}
-
-func TestRegistryChallengeURLInvalid(t *testing.T) {
-challenge := registryChallenge{Realm: "://invalid"}
-if _, err := challenge.URL(); err == nil {
-t.Error("expected error for invalid URL")
-}
-}
@@ -95,11 +95,48 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
}

const (
-numDownloadParts = 16
+// numDownloadParts is the default number of concurrent download parts for standard downloads
+numDownloadParts = 16
+// numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
+// downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
+numHFDownloadParts = 4
minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte
)

+// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
+// This includes:
+// - huggingface.co (main domain)
+// - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
+// - hf.co (shortlink domain)
+// - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
+func isHuggingFaceURL(u *url.URL) bool {
+if u == nil {
+return false
+}
+host := strings.ToLower(u.Hostname())
+return host == "huggingface.co" ||
+strings.HasSuffix(host, ".huggingface.co") ||
+host == "hf.co" ||
+strings.HasSuffix(host, ".hf.co")
+}
+
+// getNumDownloadParts returns the number of concurrent download parts to use
+// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
+// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
+// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
+func getNumDownloadParts(u *url.URL) int {
+if isHuggingFaceURL(u) {
+if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
+if n, err := strconv.Atoi(v); err == nil && n > 0 {
+return n
+}
+}
+return numHFDownloadParts
+}
+return numDownloadParts
+}
+
func (p *blobDownloadPart) Name() string {
return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -271,7 +308,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
}

g, inner := errgroup.WithContext(ctx)
-g.SetLimit(numDownloadParts)
+concurrency := getNumDownloadParts(directURL)
+if concurrency != numDownloadParts {
+slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
+}
+g.SetLimit(concurrency)
for i := range b.Parts {
part := b.Parts[i]
if part.Completed.Load() == part.Size {
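The `getNumDownloadParts` path added above only changes behavior for HuggingFace-hosted blobs: they default to 4 concurrent parts instead of 16, with `OLLAMA_HF_CONCURRENCY` as an override. A rough usage sketch, assuming a standard install where the variable is read by the server process that performs the download (as in the diff above):

```shell
# Hypothetical override: allow 8 concurrent parts for HuggingFace downloads
# instead of the reduced default of 4; non-HuggingFace registries keep 16.
OLLAMA_HF_CONCURRENCY=8 ollama serve
```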
194 server/download_test.go (new file)
@@ -0,0 +1,194 @@
+package server
+
+import (
+"net/url"
+"testing"
+
+"github.com/stretchr/testify/assert"
+)
+
+func TestIsHuggingFaceURL(t *testing.T) {
+tests := []struct {
+name string
+url string
+expected bool
+}{
+{
+name: "nil url",
+url: "",
+expected: false,
+},
+{
+name: "huggingface.co main domain",
+url: "https://huggingface.co/some/model",
+expected: true,
+},
+{
+name: "cdn-lfs.huggingface.co subdomain",
+url: "https://cdn-lfs.huggingface.co/repos/abc/123",
+expected: true,
+},
+{
+name: "cdn-lfs3.hf.co CDN domain",
+url: "https://cdn-lfs3.hf.co/repos/abc/123",
+expected: true,
+},
+{
+name: "hf.co shortlink domain",
+url: "https://hf.co/model",
+expected: true,
+},
+{
+name: "uppercase HuggingFace domain",
+url: "https://HUGGINGFACE.CO/model",
+expected: true,
+},
+{
+name: "mixed case HF domain",
+url: "https://Cdn-Lfs.HF.Co/repos",
+expected: true,
+},
+{
+name: "ollama registry",
+url: "https://registry.ollama.ai/v2/library/llama3",
+expected: false,
+},
+{
+name: "github.com",
+url: "https://github.com/ollama/ollama",
+expected: false,
+},
+{
+name: "fake huggingface domain",
+url: "https://nothuggingface.co/model",
+expected: false,
+},
+{
+name: "fake hf domain",
+url: "https://nothf.co/model",
+expected: false,
+},
+{
+name: "huggingface in path not host",
+url: "https://example.com/huggingface.co/model",
+expected: false,
+},
+}
+
+for _, tc := range tests {
+t.Run(tc.name, func(t *testing.T) {
+var u *url.URL
+if tc.url != "" {
+var err error
+u, err = url.Parse(tc.url)
+if err != nil {
+t.Fatalf("failed to parse URL: %v", err)
+}
+}
+got := isHuggingFaceURL(u)
+assert.Equal(t, tc.expected, got)
+})
+}
+}
+
+func TestGetNumDownloadParts(t *testing.T) {
+tests := []struct {
+name string
+url string
+envValue string
+expected int
+description string
+}{
+{
+name: "nil url returns default",
+url: "",
+envValue: "",
+expected: numDownloadParts,
+description: "nil URL should return standard concurrency",
+},
+{
+name: "ollama registry returns default",
+url: "https://registry.ollama.ai/v2/library/llama3",
+envValue: "",
+expected: numDownloadParts,
+description: "Ollama registry should use standard concurrency",
+},
+{
+name: "huggingface returns reduced default",
+url: "https://huggingface.co/model/repo",
+envValue: "",
+expected: numHFDownloadParts,
+description: "HuggingFace should use reduced concurrency",
+},
+{
+name: "hf.co CDN returns reduced default",
+url: "https://cdn-lfs3.hf.co/repos/abc/123",
+envValue: "",
+expected: numHFDownloadParts,
+description: "HuggingFace CDN should use reduced concurrency",
+},
+{
+name: "huggingface with env override",
+url: "https://huggingface.co/model/repo",
+envValue: "2",
+expected: 2,
+description: "OLLAMA_HF_CONCURRENCY should override default",
+},
+{
+name: "huggingface with higher env override",
+url: "https://huggingface.co/model/repo",
+envValue: "8",
+expected: 8,
+description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
+},
+{
+name: "huggingface with invalid env (non-numeric)",
+url: "https://huggingface.co/model/repo",
+envValue: "invalid",
+expected: numHFDownloadParts,
+description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
+},
+{
+name: "huggingface with invalid env (zero)",
+url: "https://huggingface.co/model/repo",
+envValue: "0",
+expected: numHFDownloadParts,
+description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
+},
+{
+name: "huggingface with invalid env (negative)",
+url: "https://huggingface.co/model/repo",
+envValue: "-1",
+expected: numHFDownloadParts,
+description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
+},
+{
+name: "non-huggingface ignores env",
+url: "https://registry.ollama.ai/v2/library/llama3",
+envValue: "2",
+expected: numDownloadParts,
+description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
+},
+}
+
+for _, tc := range tests {
+t.Run(tc.name, func(t *testing.T) {
+// Set or clear the environment variable
+if tc.envValue != "" {
+t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
+}
+
+var u *url.URL
+if tc.url != "" {
+var err error
+u, err = url.Parse(tc.url)
+if err != nil {
+t.Fatalf("failed to parse URL: %v", err)
+}
+}
+
+got := getNumDownloadParts(u)
+assert.Equal(t, tc.expected, got, tc.description)
+})
+}
+}
@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
-}, base.Host)
+})
}

if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
-}, base.Host)
+})
}

return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR

// Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
-token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
+token, err := getAuthorizationToken(ctx, challenge)
if err != nil {
return nil, err
}
@@ -51,6 +51,7 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
+imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)

const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -163,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}

+// ScheduleImageGenRunner schedules an image generation model runner.
+// This implements the imagegenapi.RunnerScheduler interface.
+func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
+m := &Model{
+Name: modelName,
+ShortName: modelName,
+ModelPath: modelName, // For image gen, ModelPath is just the model name
+Config: model.ConfigV2{
+Capabilities: []string{"image"},
+},
+}
+
+runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
+var runner *runnerRef
+select {
+case runner = <-runnerCh:
+case err := <-errCh:
+return nil, err
+}
+
+return runner.llama, nil
+}
+
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -190,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

+// Check if this is a known image generation model
+if imagegen.ResolveModelName(req.Model) != "" {
+imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
+return
+}
+
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1557,12 +1587,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
-// Experimental OpenAI-compatible image generation endpoint
-r.POST("/v1/images/generations", s.handleImageGeneration)

// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)

+// Experimental image generation support
+imagegenapi.RegisterRoutes(r, s)
+
if rc != nil {
// wrap old with new
rs := &registry.Local{
@@ -1880,62 +1911,6 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}

-func (s *Server) handleImageGeneration(c *gin.Context) {
-var req struct {
-Model string `json:"model"`
-Prompt string `json:"prompt"`
-Size string `json:"size"`
-}
-if err := c.ShouldBindJSON(&req); err != nil {
-c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-return
-}
-
-m, err := GetModel(req.Model)
-if err != nil {
-c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
-return
-}
-
-runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
-var runner *runnerRef
-select {
-case runner = <-runnerCh:
-case err := <-errCh:
-c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-return
-}
-
-// Parse size (e.g., "1024x768") into width and height
-width, height := int32(1024), int32(1024)
-if req.Size != "" {
-if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
-c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
-return
-}
-}
-
-var image []byte
-err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-Prompt: req.Prompt,
-Width: width,
-Height: height,
-}, func(resp llm.CompletionResponse) {
-if len(resp.Image) > 0 {
-image = resp.Image
-}
-})
-if err != nil {
-c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-return
-}
-
-c.JSON(http.StatusOK, gin.H{
-"created": time.Now().Unix(),
-"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
-})
-}
-
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
@@ -6,6 +6,7 @@ import (
"errors"
"log/slog"
"os"
+"slices"
"testing"
"time"

@@ -16,6 +17,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
+"github.com/ollama/ollama/types/model"
)

func TestMain(m *testing.M) {
@@ -805,8 +807,32 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }

+// TestImageGenCapabilityDetection verifies that models with "image" capability
+// are correctly identified and routed differently from language models.
+func TestImageGenCapabilityDetection(t *testing.T) {
+// Model with image capability should be detected
+imageModel := &Model{
+Config: model.ConfigV2{
+Capabilities: []string{"image"},
+},
+}
+require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
+
+// Model without image capability should not be detected
+langModel := &Model{
+Config: model.ConfigV2{
+Capabilities: []string{"completion"},
+},
+}
+require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
+
+// Empty capabilities should not match
+emptyModel := &Model{}
+require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
+}
+
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
-// loaded in the scheduler can be evicted when idle.
+// loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
@@ -838,59 +864,3 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}
-
-// TestImageGenSchedulerCoexistence verifies that image generation models
-// can coexist with language models in the scheduler and VRAM is tracked correctly.
-func TestImageGenSchedulerCoexistence(t *testing.T) {
-ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
-defer done()
-
-s := InitScheduler(ctx)
-s.getGpuFn = getGpuFn
-s.getSystemInfoFn = getSystemInfoFn
-
-// Load both an imagegen runner and a language model runner
-imageGenRunner := &runnerRef{
-model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
-modelPath: "/fake/flux/model",
-llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
-sessionDuration: 10 * time.Millisecond,
-numParallel: 1,
-refCount: 0,
-}
-
-langModelRunner := &runnerRef{
-model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
-modelPath: "/fake/llama3/model",
-llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
-sessionDuration: 10 * time.Millisecond,
-numParallel: 1,
-refCount: 0,
-}
-
-s.loadedMu.Lock()
-s.loaded["/fake/flux/model"] = imageGenRunner
-s.loaded["/fake/llama3/model"] = langModelRunner
-s.loadedMu.Unlock()
-
-// Verify both are loaded
-s.loadedMu.Lock()
-require.Len(t, s.loaded, 2)
-require.NotNil(t, s.loaded["/fake/flux/model"])
-require.NotNil(t, s.loaded["/fake/llama3/model"])
-s.loadedMu.Unlock()
-
-// Verify updateFreeSpace accounts for both
-gpus := []ml.DeviceInfo{
-{
-DeviceID: ml.DeviceID{Library: "Metal"},
-TotalMemory: 24 * format.GigaByte,
-FreeMemory: 24 * format.GigaByte,
-},
-}
-s.updateFreeSpace(gpus)
-
-// Free memory should be reduced by both models
-expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
-require.Equal(t, expectedFree, gpus[0].FreeMemory)
-}
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
-token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
+token, err := getAuthorizationToken(ctx, challenge)
if err != nil {
return err
}
50 x/README.md (new file)
@@ -0,0 +1,50 @@
+# Experimental Features
+
+## MLX Backend
+
+We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
+
+Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors.
+
+### Building ollama-mlx
+
+The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
+
+#### macOS (Apple Silicon and Intel)
+
+```bash
+# Build MLX backend libraries
+cmake --preset MLX
+cmake --build --preset MLX --parallel
+cmake --install build --component MLX
+
+# Build ollama-mlx binary
+go build -tags mlx -o ollama-mlx .
+```
+
+#### Linux (CUDA)
+
+On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
+
+```bash
+# Build MLX backend libraries with CUDA support
+cmake --preset 'MLX CUDA 13'
+cmake --build --preset 'MLX CUDA 13' --parallel
+cmake --install build --component MLX
+
+# Build ollama-mlx binary
+CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
+CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
+go build -tags mlx -o ollama-mlx .
+```
+
+#### Using build scripts
+
+The build scripts automatically create the `ollama-mlx` binary:
+
+- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
+- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
+
+## Image Generation
+
+Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
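The x/README.md added above ends with the image generation note. As a minimal sketch of exercising it after a successful build (binary location per the build-script section in that README):

```shell
# Start the MLX-enabled server; it serves the usual Ollama API plus the
# experimental image generation endpoint wired up later in this change.
./ollama-mlx serve
```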
67 x/cmd/run.go
@@ -25,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools"
)

+// MultilineState tracks the state of multiline input
+type MultilineState int
+
+const (
+MultilineNone MultilineState = iota
+MultilineSystem
+)
+
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
-AltPlaceholder: "Press Enter to send",
+AltPlaceholder: `Use """ to end multi-line input`,
})
if err != nil {
return err
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder
var format string
var system string
+var multiline MultilineState = MultilineNone

for {
line, err := scanner.Readline()
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
scanner.Prompt.UseAlt = false
sb.Reset()
+multiline = MultilineNone
continue
case err != nil:
return err
}

switch {
+case multiline != MultilineNone:
+// check if there's a multiline terminating string
+before, ok := strings.CutSuffix(line, `"""`)
+sb.WriteString(before)
+if !ok {
+fmt.Fprintln(&sb)
+continue
+}
+
+switch multiline {
+case MultilineSystem:
+system = sb.String()
+newMessage := api.Message{Role: "system", Content: system}
+if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
+messages[len(messages)-1] = newMessage
+} else {
+messages = append(messages, newMessage)
+}
+fmt.Println("Set system message.")
+sb.Reset()
+}
+
+multiline = MultilineNone
+scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
-fmt.Println("Usage: /set system <message>")
+fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue
}

-system = strings.Join(args[2:], " ")
-newMessage := api.Message{Role: "system", Content: system}
+multiline = MultilineSystem
+line := strings.Join(args[2:], " ")
+line, ok := strings.CutPrefix(line, `"""`)
+if !ok {
+multiline = MultilineNone
+} else {
+// only cut suffix if the line is multiline
+line, ok = strings.CutSuffix(line, `"""`)
+if ok {
+multiline = MultilineNone
+}
+}
+
+sb.WriteString(line)
+if multiline != MultilineNone {
+scanner.Prompt.UseAlt = true
+continue
+}
+
+system = sb.String()
+newMessage := api.Message{Role: "system", Content: sb.String()}
+// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
+// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
+sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}

-if sb.Len() > 0 {
+if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
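The run.go changes above restore `"""`-delimited multi-line input for `/set system`, with the `... ` alt prompt shown until the closing delimiter. A purely illustrative transcript of the flow this enables:

```
>>> /set system """
... You are a concise assistant.
... Answer in one sentence.
... """
Set system message.
```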
231 x/imagegen/api/handler.go (new file)
@@ -0,0 +1,231 @@
+package api
+
+import (
+"fmt"
+"net/http"
+"strconv"
+"strings"
+"time"
+
+"github.com/gin-gonic/gin"
+
+"github.com/ollama/ollama/api"
+"github.com/ollama/ollama/llm"
+"github.com/ollama/ollama/x/imagegen"
+)
+
+// RunnerScheduler is the interface for scheduling a model runner.
+// This is implemented by server.Server to avoid circular imports.
+type RunnerScheduler interface {
+ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
+}
+
+// RegisterRoutes registers the image generation API routes.
+func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
+r.POST("/v1/images/generations", func(c *gin.Context) {
+ImageGenerationHandler(c, scheduler)
+})
+}
+
+// ImageGenerationHandler handles OpenAI-compatible image generation requests.
+func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
+var req ImageGenerationRequest
+if err := c.BindJSON(&req); err != nil {
+c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
+return
+}
+
+// Validate required fields
+if req.Model == "" {
+c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
+return
+}
+if req.Prompt == "" {
+c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
+return
+}
+
+// Apply defaults
+if req.N == 0 {
+req.N = 1
+}
+if req.Size == "" {
+req.Size = "1024x1024"
+}
+if req.ResponseFormat == "" {
+req.ResponseFormat = "b64_json"
+}
+
+// Verify model exists
+if imagegen.ResolveModelName(req.Model) == "" {
+c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
+return
+}
+
+// Parse size
+width, height := parseSize(req.Size)
+
+// Build options - we repurpose NumCtx/NumGPU for width/height
+opts := api.Options{}
+opts.NumCtx = int(width)
+opts.NumGPU = int(height)
+
+// Schedule runner
+runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
+if err != nil {
+status := http.StatusInternalServerError
+if strings.Contains(err.Error(), "not found") {
+status = http.StatusNotFound
+}
+c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
+return
+}
+
+// Build completion request
+completionReq := llm.CompletionRequest{
+Prompt: req.Prompt,
+Options: &opts,
+}
+
+if req.Stream {
+handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
+} else {
+handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
+}
+}
+
+func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
+c.Header("Content-Type", "text/event-stream")
+c.Header("Cache-Control", "no-cache")
+c.Header("Connection", "keep-alive")
+
+var imageBase64 string
+err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
+if resp.Done {
+imageBase64 = extractBase64(resp.Content)
+} else {
+progress := parseProgress(resp.Content)
+if progress.Total > 0 {
+c.SSEvent("progress", progress)
+c.Writer.Flush()
+}
+}
+})
+if err != nil {
+c.SSEvent("error", gin.H{"error": err.Error()})
+return
+}
+
+c.SSEvent("done", buildResponse(imageBase64, format))
+}
+
+func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
+var imageBase64 string
+err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
+if resp.Done {
+imageBase64 = extractBase64(resp.Content)
+}
+})
+if err != nil {
+c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+return
+}
+
+c.JSON(http.StatusOK, buildResponse(imageBase64, format))
+}
+
+func parseSize(size string) (int32, int32) {
+parts := strings.Split(size, "x")
+if len(parts) != 2 {
+return 1024, 1024
+}
+w, _ := strconv.Atoi(parts[0])
+h, _ := strconv.Atoi(parts[1])
+if w == 0 {
+w = 1024
+}
+if h == 0 {
+h = 1024
+}
+return int32(w), int32(h)
+}
+
+func extractBase64(content string) string {
+if strings.HasPrefix(content, "IMAGE_BASE64:") {
+return content[13:]
+}
+return ""
+}
+
+func parseProgress(content string) ImageProgressEvent {
+var step, total int
+fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
+return ImageProgressEvent{Step: step, Total: total}
+}
+
+func buildResponse(imageBase64, format string) ImageGenerationResponse {
+resp := ImageGenerationResponse{
+Created: time.Now().Unix(),
+Data: make([]ImageData, 1),
+}
+
+if imageBase64 == "" {
+return resp
+}
+
+if format == "url" {
+// URL format not supported when using base64 transfer
+resp.Data[0].B64JSON = imageBase64
+} else {
+resp.Data[0].B64JSON = imageBase64
+}
+
+return resp
+}
+
+// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
+// This allows routes.go to delegate image generation with minimal code.
+func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
+opts := api.Options{}
+
+// Schedule runner
+runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
+if err != nil {
+c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+return
+}
+
+// Build completion request
+completionReq := llm.CompletionRequest{
+Prompt: prompt,
+Options: &opts,
+}
+
+// Stream responses via channel
+ch := make(chan any)
+go func() {
+defer close(ch)
+err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
+ch <- GenerateResponse{
+Model: modelName,
+CreatedAt: time.Now().UTC(),
+Response: resp.Content,
+Done: resp.Done,
+}
+})
+if err != nil {
+// Log error but don't block - channel is already being consumed
+_ = err
+}
+}()
+
+streamFn(c, ch)
+}
+
+// GenerateResponse matches api.GenerateResponse structure for streaming.
+type GenerateResponse struct {
+Model string `json:"model"`
+CreatedAt time.Time `json:"created_at"`
+Response string `json:"response"`
+Done bool `json:"done"`
+}
31 x/imagegen/api/types.go (new file)
@@ -0,0 +1,31 @@
+// Package api provides OpenAI-compatible image generation API types.
+package api
+
+// ImageGenerationRequest is an OpenAI-compatible image generation request.
+type ImageGenerationRequest struct {
+Model string `json:"model"`
+Prompt string `json:"prompt"`
+N int `json:"n,omitempty"`
+Size string `json:"size,omitempty"`
+ResponseFormat string `json:"response_format,omitempty"`
+Stream bool `json:"stream,omitempty"`
+}
+
+// ImageGenerationResponse is an OpenAI-compatible image generation response.
+type ImageGenerationResponse struct {
+Created int64 `json:"created"`
+Data []ImageData `json:"data"`
+}
+
+// ImageData contains the generated image data.
+type ImageData struct {
+URL string `json:"url,omitempty"`
+B64JSON string `json:"b64_json,omitempty"`
+RevisedPrompt string `json:"revised_prompt,omitempty"`
+}
+
+// ImageProgressEvent is sent during streaming to indicate generation progress.
+type ImageProgressEvent struct {
+Step int `json:"step"`
+Total int `json:"total"`
+}
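Together, handler.go and types.go above define the experimental OpenAI-compatible image generation endpoint registered via `RegisterRoutes`. A hedged example request against a locally running server on the default Ollama port (the model name below is a placeholder; which names resolve depends on `imagegen.ResolveModelName`):

```shell
curl http://localhost:11434/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "model": "my-image-model",
    "prompt": "a watercolor fox in a forest",
    "size": "1024x1024",
    "response_format": "b64_json"
  }'
```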
@@ -7,6 +7,7 @@ package imagegen
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -38,17 +39,77 @@ func DefaultOptions() ImageGenOptions {
 	return ImageGenOptions{
 		Width:  1024,
 		Height: 1024,
-		Steps:  0, // 0 means model default
+		Steps:  9,
 		Seed:   0, // 0 means random
 	}
 }
+
+// ModelInfo contains metadata about an image generation model.
+type ModelInfo struct {
+	Architecture   string
+	ParameterCount int64
+	Quantization   string
+}
+
+// GetModelInfo returns metadata about an image generation model.
+func GetModelInfo(modelName string) (*ModelInfo, error) {
+	manifest, err := LoadManifest(modelName)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load manifest: %w", err)
+	}
+
+	info := &ModelInfo{}
+
+	// Read model_index.json for architecture, parameter count, and quantization
+	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
+		var index struct {
+			Architecture   string `json:"architecture"`
+			ParameterCount int64  `json:"parameter_count"`
+			Quantization   string `json:"quantization"`
+		}
+		if json.Unmarshal(data, &index) == nil {
+			info.Architecture = index.Architecture
+			info.ParameterCount = index.ParameterCount
+			info.Quantization = index.Quantization
+		}
+	}
+
+	// Fallback: detect quantization from tensor names if not in config
+	if info.Quantization == "" {
+		for _, layer := range manifest.Manifest.Layers {
+			if strings.HasSuffix(layer.Name, ".weight_scale") {
+				info.Quantization = "FP8"
+				break
+			}
+		}
+		if info.Quantization == "" {
+			info.Quantization = "BF16"
+		}
+	}
+
+	// Fallback: estimate parameter count if not in config
+	if info.ParameterCount == 0 {
+		var totalSize int64
+		for _, layer := range manifest.Manifest.Layers {
+			if layer.MediaType == "application/vnd.ollama.image.tensor" {
+				if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
+					totalSize += layer.Size
+				}
+			}
+		}
+		// Assume BF16 (2 bytes/param) as rough estimate
+		info.ParameterCount = totalSize / 2
+	}
+
+	return info, nil
+}
+
 // RegisterFlags adds image generation flags to the given command.
 // Flags are hidden since they only apply to image generation models.
 func RegisterFlags(cmd *cobra.Command) {
 	cmd.Flags().Int("width", 1024, "Image width")
 	cmd.Flags().Int("height", 1024, "Image height")
-	cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
+	cmd.Flags().Int("steps", 9, "Denoising steps")
 	cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
 	cmd.Flags().String("negative", "", "Negative prompt")
 	cmd.Flags().MarkHidden("width")
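As a rough usage sketch (assuming the import path below matches this repository layout and the named model has already been pulled), GetModelInfo can summarize a local image model:

```go
package main

import (
	"fmt"
	"log"

	// Assumed import path for the package shown in this hunk.
	"github.com/ollama/ollama/x/imagegen"
)

func main() {
	// "my-image-model" is a placeholder model name.
	info, err := imagegen.GetModelInfo("my-image-model")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("architecture=%s parameters=%.1fB quantization=%s\n",
		info.Architecture, float64(info.ParameterCount)/1e9, info.Quantization)
}
```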
@@ -91,18 +152,23 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
 }
 
 // generateImageWithOptions generates an image with the given options.
-// Note: opts are currently unused as the native API doesn't support size parameters.
-// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
-func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
+func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
 
+	// Build request with image gen options encoded in Options fields
+	// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
 	req := &api.GenerateRequest{
 		Model:  modelName,
 		Prompt: prompt,
-		// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
+		Options: map[string]any{
+			"num_ctx":     opts.Width,
+			"num_gpu":     opts.Height,
+			"num_predict": opts.Steps,
+			"seed":        opts.Seed,
+		},
 	}
 	if keepAlive != nil {
 		req.KeepAlive = keepAlive
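The hunk above routes width, height, steps, and seed through the existing generate Options map (num_ctx, num_gpu, num_predict, seed). A hedged client-side sketch of the same encoding using the public api package; the model name is a placeholder:

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Image generation options ride along in the regular Options map,
	// using the encoding described above: num_ctx=width, num_gpu=height,
	// num_predict=steps, seed=seed.
	req := &api.GenerateRequest{
		Model:  "my-image-model", // placeholder
		Prompt: "a red bicycle on a beach",
		Options: map[string]any{
			"num_ctx":     1024,
			"num_gpu":     1024,
			"num_predict": 9,
			"seed":        42,
		},
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		// Responses stream back; the final one carries the image payload.
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```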
@@ -12,7 +12,6 @@ import (
 	"path/filepath"
 	"runtime/pprof"
 
-	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/models/gemma3"
 	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -49,7 +48,7 @@ func main() {
 	// Image generation params
 	width := flag.Int("width", 1024, "Image width")
 	height := flag.Int("height", 1024, "Image height")
-	steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
+	steps := flag.Int("steps", 9, "Denoising steps")
 	seed := flag.Int64("seed", 42, "Random seed")
 	out := flag.String("output", "output.png", "Output path")
@@ -175,63 +175,3 @@ func (m *ModelManifest) HasTensorLayers() bool {
 	}
 	return false
 }
-
-// ModelInfo contains metadata about an image generation model.
-type ModelInfo struct {
-	Architecture   string
-	ParameterCount int64
-	Quantization   string
-}
-
-// GetModelInfo returns metadata about an image generation model.
-func GetModelInfo(modelName string) (*ModelInfo, error) {
-	manifest, err := LoadManifest(modelName)
-	if err != nil {
-		return nil, fmt.Errorf("failed to load manifest: %w", err)
-	}
-
-	info := &ModelInfo{}
-
-	// Read model_index.json for architecture, parameter count, and quantization
-	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
-		var index struct {
-			Architecture   string `json:"architecture"`
-			ParameterCount int64  `json:"parameter_count"`
-			Quantization   string `json:"quantization"`
-		}
-		if json.Unmarshal(data, &index) == nil {
-			info.Architecture = index.Architecture
-			info.ParameterCount = index.ParameterCount
-			info.Quantization = index.Quantization
-		}
-	}
-
-	// Fallback: detect quantization from tensor names if not in config
-	if info.Quantization == "" {
-		for _, layer := range manifest.Manifest.Layers {
-			if strings.HasSuffix(layer.Name, ".weight_scale") {
-				info.Quantization = "FP8"
-				break
-			}
-		}
-		if info.Quantization == "" {
-			info.Quantization = "BF16"
-		}
-	}
-
-	// Fallback: estimate parameter count if not in config
-	if info.ParameterCount == 0 {
-		var totalSize int64
-		for _, layer := range manifest.Manifest.Layers {
-			if layer.MediaType == "application/vnd.ollama.image.tensor" {
-				if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
-					totalSize += layer.Size
-				}
-			}
-		}
-		// Assume BF16 (2 bytes/param) as rough estimate
-		info.ParameterCount = totalSize / 2
-	}
-
-	return info, nil
-}
@@ -95,3 +95,8 @@ func EstimateVRAM(modelName string) uint64 {
 	}
 	return 21 * GB
 }
+
+// HasTensorLayers checks if the given model has tensor layers.
+func HasTensorLayers(modelName string) bool {
+	return ResolveModelName(modelName) != ""
+}
@@ -94,6 +94,13 @@ func TestEstimateVRAMDefault(t *testing.T) {
 	}
 }
 
+func TestHasTensorLayers(t *testing.T) {
+	// Non-existent model should return false
+	if HasTensorLayers("nonexistent-model") {
+		t.Error("HasTensorLayers() should return false for non-existent model")
+	}
+}
+
 func TestResolveModelName(t *testing.T) {
 	// Non-existent model should return empty string
 	result := ResolveModelName("nonexistent-model")
@@ -1,5 +1,7 @@
 # MLX Memory Management
 
+| This package will get consolidated with `x/ml/backend/mlx` in the future.
+
 ## Automatic Tracking
 
 All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
@@ -9,7 +9,6 @@ import (
 	"path/filepath"
 	"time"
 
-	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/cache"
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -173,7 +172,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
 		cfg.Height = 1024
 	}
 	if cfg.Steps <= 0 {
-		cfg.Steps = 50
+		cfg.Steps = 30
 	}
 	if cfg.CFGScale <= 0 {
 		cfg.CFGScale = 4.0
@@ -194,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
 		cfg.Height = 1024
 	}
 	if cfg.Steps <= 0 {
-		cfg.Steps = 9 // Z-Image turbo default
+		cfg.Steps = 9 // Turbo default
 	}
 	if cfg.CFGScale <= 0 {
 		cfg.CFGScale = 4.0
@@ -136,8 +136,16 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	// Model applies its own defaults for width/height/steps
-	// Only seed needs to be set here if not provided
+	// Apply defaults
+	if req.Width <= 0 {
+		req.Width = 1024
+	}
+	if req.Height <= 0 {
+		req.Height = 1024
+	}
+	if req.Steps <= 0 {
+		req.Steps = 9
+	}
 	if req.Seed <= 0 {
 		req.Seed = time.Now().UnixNano()
 	}
@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"bytes"
 	"context"
-	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -26,11 +25,6 @@ import (
 )
 
 // Server wraps an image generation subprocess to implement llm.LlamaServer.
-//
-// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
-// like any other model. The plan is to eventually bring this into the llm/ package
-// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
-// separate allows for independent iteration on image generation support.
 type Server struct {
 	mu  sync.Mutex
 	cmd *exec.Cmd
@@ -43,6 +37,22 @@ type Server struct {
 	lastErrLock sync.Mutex
 }
 
+// completionRequest is sent to the subprocess
+type completionRequest struct {
+	Prompt string `json:"prompt"`
+	Width  int32  `json:"width,omitempty"`
+	Height int32  `json:"height,omitempty"`
+	Steps  int    `json:"steps,omitempty"`
+	Seed   int64  `json:"seed,omitempty"`
+}
+
+// completionResponse is received from the subprocess
+type completionResponse struct {
+	Content string `json:"content,omitempty"`
+	Image   string `json:"image,omitempty"`
+	Done    bool   `json:"done"`
+}
+
 // NewServer spawns a new image generation subprocess and waits until it's ready.
 func NewServer(modelName string) (*Server, error) {
 	// Validate platform support before attempting to start
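For context, this is roughly the JSON that crosses the subprocess boundary. The structs below are local mirrors of the unexported types added above (field names come from the struct tags in the diff); the payload values are illustrative.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local mirrors of the unexported completionRequest/completionResponse types,
// matching the JSON tags shown in the hunk above.
type completionRequest struct {
	Prompt string `json:"prompt"`
	Width  int32  `json:"width,omitempty"`
	Height int32  `json:"height,omitempty"`
	Steps  int    `json:"steps,omitempty"`
	Seed   int64  `json:"seed,omitempty"`
}

type completionResponse struct {
	Content string `json:"content,omitempty"`
	Image   string `json:"image,omitempty"`
	Done    bool   `json:"done"`
}

func main() {
	req := completionRequest{Prompt: "a foggy harbor", Width: 1024, Height: 1024, Steps: 9, Seed: 42}
	body, _ := json.Marshal(req)
	fmt.Println(string(body)) // {"prompt":"a foggy harbor","width":1024,"height":1024,"steps":9,"seed":42}

	// A final streamed line from the runner might look like this (image truncated).
	line := []byte(`{"image":"iVBORw0KGgo...","done":true}`)
	var resp completionResponse
	_ = json.Unmarshal(line, &resp)
	fmt.Println(resp.Done, len(resp.Image) > 0)
}
```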
@@ -129,6 +139,7 @@ func NewServer(modelName string) (*Server, error) {
 	for scanner.Scan() {
 		line := scanner.Text()
 		slog.Warn("image-runner", "msg", line)
+		// Capture last error line for better error reporting
 		s.lastErrLock.Lock()
 		s.lastErr = line
 		s.lastErrLock.Unlock()
@@ -160,6 +171,7 @@ func (s *Server) ModelPath() string {
 	return s.modelName
 }
 
+// Load is called by the scheduler after the server is created.
 func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
 	return nil, nil
 }
@@ -192,16 +204,20 @@ func (s *Server) waitUntilRunning() error {
 	for {
 		select {
 		case err := <-s.done:
-			// Include recent stderr lines for better error context
-			errMsg := s.getLastErr()
-			if errMsg != "" {
-				return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
+			// Include last stderr line for better error context
+			s.lastErrLock.Lock()
+			lastErr := s.lastErr
+			s.lastErrLock.Unlock()
+			if lastErr != "" {
+				return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
 			}
 			return fmt.Errorf("image runner exited unexpectedly: %w", err)
 		case <-timeout:
-			errMsg := s.getLastErr()
-			if errMsg != "" {
-				return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
+			s.lastErrLock.Lock()
+			lastErr := s.lastErr
+			s.lastErrLock.Unlock()
+			if lastErr != "" {
+				return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
 			}
 			return errors.New("timeout waiting for image runner to start")
 		case <-ticker.C:
@@ -213,39 +229,44 @@ func (s *Server) waitUntilRunning() error {
 	}
 }
 
-// getLastErr returns the last stderr line.
-func (s *Server) getLastErr() string {
-	s.lastErrLock.Lock()
-	defer s.lastErrLock.Unlock()
-	return s.lastErr
-}
-
-func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
+// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
+func (s *Server) WaitUntilRunning(ctx context.Context) error {
+	return nil
+}
 
+// Completion generates an image from the prompt via the subprocess.
 func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
-	seed := req.Seed
-	if seed == 0 {
-		seed = time.Now().UnixNano()
-	}
-
-	// Build request for subprocess
-	creq := struct {
-		Prompt string `json:"prompt"`
-		Width  int32  `json:"width,omitempty"`
-		Height int32  `json:"height,omitempty"`
-		Seed   int64  `json:"seed,omitempty"`
-	}{
+	// Build request
+	creq := completionRequest{
 		Prompt: req.Prompt,
-		Width:  req.Width,
-		Height: req.Height,
-		Seed:   seed,
+		Width:  1024,
+		Height: 1024,
+		Steps:  9,
+		Seed:   time.Now().UnixNano(),
 	}
 
+	if req.Options != nil {
+		if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
+			creq.Width = int32(req.Options.NumCtx)
+		}
+		if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
+			creq.Height = int32(req.Options.NumGPU)
+		}
+		if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
+			creq.Steps = req.Options.NumPredict
+		}
+		if req.Options.Seed > 0 {
+			creq.Seed = int64(req.Options.Seed)
+		}
+	}
+
+	// Encode request body
 	body, err := json.Marshal(creq)
 	if err != nil {
 		return err
 	}
 
+	// Send request to subprocess
 	url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
 	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
 	if err != nil {
@@ -260,40 +281,30 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
-		return fmt.Errorf("request failed: %d", resp.StatusCode)
+		return fmt.Errorf("completion request failed: %d", resp.StatusCode)
 	}
 
+	// Stream responses - use large buffer for base64 image data
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
 	for scanner.Scan() {
-		// Parse subprocess response (has singular "image" field)
-		var raw struct {
-			Image   string `json:"image,omitempty"`
-			Content string `json:"content,omitempty"`
-			Done    bool   `json:"done"`
-			Step    int    `json:"step,omitempty"`
-			Total   int    `json:"total,omitempty"`
-		}
-		if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
+		var cresp completionResponse
+		if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
 			continue
 		}
 
-		// Convert to llm.CompletionResponse
-		cresp := llm.CompletionResponse{
-			Content: raw.Content,
-			Done:    raw.Done,
-			Step:    raw.Step,
-			Total:   raw.Total,
-		}
-		if raw.Image != "" {
-			if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
-				cresp.Image = data
-			}
+		content := cresp.Content
+		// If this is the final response with an image, encode it in the content
+		if cresp.Done && cresp.Image != "" {
+			content = "IMAGE_BASE64:" + cresp.Image
 		}
 
-		fn(cresp)
+		fn(llm.CompletionResponse{
+			Content: content,
+			Done:    cresp.Done,
+		})
 		if cresp.Done {
-			return nil
+			break
 		}
 	}
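With this change the final streamed response carries the image inline in Content behind an "IMAGE_BASE64:" prefix instead of a decoded byte slice. A minimal sketch of how a consumer might peel that off, assuming it already has the content string in hand:

```go
package main

import (
	"encoding/base64"
	"fmt"
	"os"
	"strings"
)

const imagePrefix = "IMAGE_BASE64:"

// saveIfImage writes the decoded image bytes to path when content carries an inline image.
func saveIfImage(content, path string) (bool, error) {
	if !strings.HasPrefix(content, imagePrefix) {
		return false, nil
	}
	data, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(content, imagePrefix))
	if err != nil {
		return false, err
	}
	return true, os.WriteFile(path, data, 0o644)
}

func main() {
	// Illustrative only: a real payload would be the full base64-encoded PNG.
	content := imagePrefix + base64.StdEncoding.EncodeToString([]byte("not really a png"))
	ok, err := saveIfImage(content, "out.png")
	fmt.Println(ok, err)
}
```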
@@ -335,18 +346,22 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
 	return s.vramSize
 }
 
+// Embedding is not supported for image generation models.
 func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
-	return nil, 0, errors.New("not supported")
+	return nil, 0, errors.New("embedding not supported for image generation models")
 }
 
+// Tokenize is not supported for image generation models.
 func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
-	return nil, errors.New("not supported")
+	return nil, errors.New("tokenize not supported for image generation models")
 }
 
+// Detokenize is not supported for image generation models.
 func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
-	return "", errors.New("not supported")
+	return "", errors.New("detokenize not supported for image generation models")
 }
 
+// Pid returns the subprocess PID.
 func (s *Server) Pid() int {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -356,9 +371,17 @@ func (s *Server) Pid() int {
 	return -1
 }
 
-func (s *Server) GetPort() int                                       { return s.port }
-func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
+// GetPort returns the subprocess port.
+func (s *Server) GetPort() int {
+	return s.port
+}
+
+// GetDeviceInfos returns nil since we don't track GPU info.
+func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+	return nil
+}
+
+// HasExited returns true if the subprocess has exited.
 func (s *Server) HasExited() bool {
 	select {
 	case <-s.done:
x/kvcache/cache.go (new file, +77 lines)
@@ -0,0 +1,77 @@
+package kvcache
+
+import (
+	"errors"
+
+	"github.com/ollama/ollama/x/ml"
+	"github.com/ollama/ollama/x/model/input"
+)
+
+var (
+	ErrKvCacheFull  = errors.New("could not find a kv cache slot")
+	ErrNotSupported = errors.New("model does not support operation")
+)
+
+type Cache interface {
+	// ** used by model implementations **
+
+	// SetLayer sets the active layer of the cache
+	SetLayer(layer int)
+
+	// Get returns the history of key and value tensors plus a mask
+	//
+	// The shape of the tensors is documented in the specific
+	// cache implementation used.
+	Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
+
+	// Put stores a batch of key and value in the cache
+	//
+	// The shape of the tensors is documented in the specific
+	// cache implementation used.
+	Put(ctx ml.Context, key, value ml.Tensor)
+
+	// SetConfig controls optimizations (mostly backend-specific) that may transform
+	// the output of the cache to work better with specific kernels. If not called,
+	// the backend settings will be used. This works well when calling Attention.
+	//
+	// The config can be overridden by models, especially if they require vanilla
+	// output when implementing their own version of attention. To do this, pass
+	// an empty ml.CacheConfig.
+	//
+	// Most models will not need to use this.
+	SetConfig(ml.CacheConfig)
+
+	// ** cache management **
+
+	// Init sets up runtime parameters.
+	// backend: Used to allocate cache data storage and execute management operations (such as defrag)
+	// dtype: The data type for storing cache entries
+	// maxSequences: The maximum number of sequences stored in the cache - across all batches
+	// capacity: The number of cache entries to store, per sequence
+	// maxBatch: The maximum number of tokens that can occur in a single batch
+	Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
+
+	// Close closes the cache and frees resources associated with it
+	Close()
+
+	// StartForward is called before the start of the model's forward pass.
+	// For each token in the coming batch, there must be a corresponding
+	// entry in positions and seqs. reserve is to preallocate memory
+	// without actually storing data in the cache.
+	StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
+
+	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
+	CopyPrefix(srcSeq, dstSeq int, len int32)
+
+	// CanResume returns true if the cache can continue with the next token at
+	// the given position and sequence. Assumes that the caller has already
+	// verified the contents of the cache.
+	CanResume(seq int, pos int32) bool
+
+	// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
+	// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
+	//
+	// If an error occurs, the entire context for the sequence should be
+	// removed by calling Remove(seq, 0, math.MaxInt32)
+	Remove(seq int, beginIndex, endIndex int32) error
+}
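To make the intended call order of this interface concrete, here is a hedged sketch of how a model forward pass might drive a Cache implementation; the per-layer key/value production is a placeholder closure, not an API from this repository.

```go
package example

import (
	"github.com/ollama/ollama/x/kvcache"
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model/input"
)

// runBatch sketches the cache protocol for one forward pass over nLayers layers:
// StartForward once per batch, then SetLayer/Put/Get per layer.
func runBatch(ctx ml.Context, cache kvcache.Cache, batch input.Batch, nLayers int,
	kv func(layer int) (k, v ml.Tensor)) error {
	// reserve=false because data is really being stored, not just sized.
	if err := cache.StartForward(ctx, batch, false); err != nil {
		return err
	}
	for layer := 0; layer < nLayers; layer++ {
		cache.SetLayer(layer)                // select this layer's K/V store
		k, v := kv(layer)                    // model-produced keys/values (placeholder)
		cache.Put(ctx, k, v)                 // append this batch's entries
		pastK, pastV, mask := cache.Get(ctx) // full history plus attention mask
		_, _, _ = pastK, pastV, mask         // a real model would run attention here
	}
	return nil
}
```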
x/kvcache/causal.go (new file, +797 lines)
@@ -0,0 +1,797 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
// import (
|
||||||
|
// "errors"
|
||||||
|
// "fmt"
|
||||||
|
// "log/slog"
|
||||||
|
// "math"
|
||||||
|
// "slices"
|
||||||
|
|
||||||
|
// "github.com/ollama/ollama/ml"
|
||||||
|
// "github.com/ollama/ollama/model/input"
|
||||||
|
// )
|
||||||
|
|
||||||
|
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||||
|
|
||||||
|
// // Causal cache stores K and V tensors according to their position in the
|
||||||
|
// // sequence. Returns the history and a mask for attending to past tokens
|
||||||
|
// //
|
||||||
|
// // The tensors are of shape embed dim, kv heads, batch size
|
||||||
|
// // The mask is of shape history size, batch size
|
||||||
|
// type Causal struct {
|
||||||
|
// DType ml.DType
|
||||||
|
|
||||||
|
// // swaWindowSize is the number of tokens that will be included in the mask
|
||||||
|
// // during attention operations. swaMemorySize is the number of tokens that
|
||||||
|
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
|
||||||
|
// // for unlimited or if sliding window attention is not being used.
|
||||||
|
// swaWindowSize int32
|
||||||
|
// swaMemorySize int32
|
||||||
|
|
||||||
|
// chunkSize int32
|
||||||
|
|
||||||
|
// opts CausalOptions
|
||||||
|
|
||||||
|
// // maxBatch is the largest batch that we might receive
|
||||||
|
// maxBatch int
|
||||||
|
|
||||||
|
// // config controls mostly backend-specific optimizations
|
||||||
|
// config *ml.CacheConfig
|
||||||
|
|
||||||
|
// // ** current forward pass **
|
||||||
|
|
||||||
|
// // size of the current batch
|
||||||
|
// curBatchSize int
|
||||||
|
|
||||||
|
// // locations for data storage for this batch
|
||||||
|
// curLoc ml.Tensor
|
||||||
|
|
||||||
|
// // mask of the cache as used by this batch
|
||||||
|
// curMask ml.Tensor
|
||||||
|
|
||||||
|
// // the active layer for Get and Put
|
||||||
|
// curLayer int
|
||||||
|
|
||||||
|
// // locations in the cache that are needed for this batch
|
||||||
|
// curCellRange cellRange
|
||||||
|
|
||||||
|
// // curSequences is the sequences corresponding to this pass's entries in the cache
|
||||||
|
// curSequences []int
|
||||||
|
|
||||||
|
// // curPositions is the positions corresponding to this pass's entries in the cache
|
||||||
|
// curPositions []int32
|
||||||
|
|
||||||
|
// // ** cache metadata **
|
||||||
|
|
||||||
|
// // for each possible location in the cache, stores the position and set of sequences
|
||||||
|
// // that reference the data there
|
||||||
|
// cells []cacheCell
|
||||||
|
|
||||||
|
// // maps from sequence to the range of locations where it is stored in the cache
|
||||||
|
// cellRanges map[int]cellRange
|
||||||
|
|
||||||
|
// // ** cache data storage **
|
||||||
|
|
||||||
|
// shiftFn shiftFn
|
||||||
|
// backend ml.Backend
|
||||||
|
// ctxs map[int]ml.Context
|
||||||
|
// keys, values map[int]ml.Tensor
|
||||||
|
|
||||||
|
// kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type cacheCell struct {
|
||||||
|
// pos int32
|
||||||
|
// sequences []int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type cellRange struct {
|
||||||
|
// min int
|
||||||
|
// max int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func NewCausalCache(shift shiftFn) *Causal {
|
||||||
|
// return &Causal{
|
||||||
|
// shiftFn: shift,
|
||||||
|
// ctxs: make(map[int]ml.Context),
|
||||||
|
// keys: make(map[int]ml.Tensor),
|
||||||
|
// values: make(map[int]ml.Tensor),
|
||||||
|
// kHeadDims: make(map[int]int),
|
||||||
|
// vHeadDims: make(map[int]int),
|
||||||
|
// numKVHeads: make(map[int]int),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||||
|
// return &Causal{
|
||||||
|
// swaWindowSize: windowSize,
|
||||||
|
// shiftFn: shift,
|
||||||
|
// ctxs: make(map[int]ml.Context),
|
||||||
|
// keys: make(map[int]ml.Tensor),
|
||||||
|
// values: make(map[int]ml.Tensor),
|
||||||
|
// kHeadDims: make(map[int]int),
|
||||||
|
// vHeadDims: make(map[int]int),
|
||||||
|
// numKVHeads: make(map[int]int),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
|
||||||
|
// return &Causal{
|
||||||
|
// swaWindowSize: windowSize,
|
||||||
|
// swaMemorySize: memorySize,
|
||||||
|
// shiftFn: shift,
|
||||||
|
// ctxs: make(map[int]ml.Context),
|
||||||
|
// keys: make(map[int]ml.Tensor),
|
||||||
|
// values: make(map[int]ml.Tensor),
|
||||||
|
// kHeadDims: make(map[int]int),
|
||||||
|
// vHeadDims: make(map[int]int),
|
||||||
|
// numKVHeads: make(map[int]int),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||||
|
// return &Causal{
|
||||||
|
// chunkSize: chunkSize,
|
||||||
|
// shiftFn: shift,
|
||||||
|
// ctxs: make(map[int]ml.Context),
|
||||||
|
// keys: make(map[int]ml.Tensor),
|
||||||
|
// values: make(map[int]ml.Tensor),
|
||||||
|
// kHeadDims: make(map[int]int),
|
||||||
|
// vHeadDims: make(map[int]int),
|
||||||
|
// numKVHeads: make(map[int]int),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
// if c.config == nil {
|
||||||
|
// var config ml.CacheConfig
|
||||||
|
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
|
// config = cc.CacheConfig()
|
||||||
|
// }
|
||||||
|
// c.config = &config
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if c.config.CachePadding == 0 {
|
||||||
|
// c.config.CachePadding = 1
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if c.config.MaskBatchPadding == 0 {
|
||||||
|
// c.config.MaskBatchPadding = 1
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // TODO what types do we handle here?
|
||||||
|
// // if c.config.MaskDType == ml.DTypeOther {
|
||||||
|
// // c.config.MaskDType = ml.DTypeFloat32
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// if c.swaWindowSize == 0 {
|
||||||
|
// c.swaWindowSize = math.MaxInt32
|
||||||
|
// }
|
||||||
|
// if c.swaMemorySize == 0 {
|
||||||
|
// c.swaMemorySize = c.swaWindowSize
|
||||||
|
// }
|
||||||
|
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||||
|
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||||
|
// // causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||||
|
// // because the extra token will live in the batch buffer and won't get overwritten if we
|
||||||
|
// // only have a single sequence.
|
||||||
|
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||||
|
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||||
|
// }
|
||||||
|
// if int(c.swaMemorySize) >= capacity {
|
||||||
|
// c.swaMemorySize = math.MaxInt32
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if c.swaMemorySize < c.swaWindowSize {
|
||||||
|
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// var cacheSize int
|
||||||
|
// if c.swaMemorySize == math.MaxInt32 {
|
||||||
|
// cacheSize = maxSequences * capacity
|
||||||
|
// } else {
|
||||||
|
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
|
||||||
|
// }
|
||||||
|
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||||
|
// c.cells = make([]cacheCell, cacheSize)
|
||||||
|
|
||||||
|
// c.DType = dtype
|
||||||
|
// c.cellRanges = make(map[int]cellRange)
|
||||||
|
// c.backend = backend
|
||||||
|
// c.maxBatch = maxBatch
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||||
|
// if c.config != nil {
|
||||||
|
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.config = &config
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) Close() {
|
||||||
|
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||||
|
// for _, ctx := range c.ctxs {
|
||||||
|
// ctx.Close()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||||
|
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
|
||||||
|
// // panic("XXX Causal.StartForward")
|
||||||
|
// c.curBatchSize = len(batch.Positions)
|
||||||
|
// c.curSequences = batch.Sequences
|
||||||
|
// c.curPositions = batch.Positions
|
||||||
|
// c.opts.Except = nil
|
||||||
|
|
||||||
|
// var locs []int32
|
||||||
|
// if !reserve {
|
||||||
|
// c.updateSlidingWindow()
|
||||||
|
|
||||||
|
// var err error
|
||||||
|
// locs, err = c.findLocs()
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
|
||||||
|
|
||||||
|
// for i, pos := range batch.Positions {
|
||||||
|
// seq := batch.Sequences[i]
|
||||||
|
// loc := int(locs[i])
|
||||||
|
|
||||||
|
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||||
|
|
||||||
|
// seqRange, ok := c.cellRanges[seq]
|
||||||
|
// if !ok {
|
||||||
|
// seqRange = newRange()
|
||||||
|
// }
|
||||||
|
|
||||||
|
// seqRange.min = min(seqRange.min, loc)
|
||||||
|
// c.curCellRange.min = min(c.curCellRange.min, loc)
|
||||||
|
|
||||||
|
// seqRange.max = max(seqRange.max, loc)
|
||||||
|
// c.curCellRange.max = max(c.curCellRange.max, loc)
|
||||||
|
|
||||||
|
// c.cellRanges[seq] = seqRange
|
||||||
|
// }
|
||||||
|
// } else {
|
||||||
|
// // If we are reserving memory, don't update any of the cache metadata but set the size
|
||||||
|
// // to the worst case.
|
||||||
|
// locs = make([]int32, c.curBatchSize)
|
||||||
|
// for i := range locs {
|
||||||
|
// locs[i] = int32(i)
|
||||||
|
// }
|
||||||
|
// c.curCellRange.min = 0
|
||||||
|
// c.curCellRange.max = len(c.cells) - 1
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // XXX Building up the locs for what's already processed (if any)
|
||||||
|
// dummyLocs := []int{}
|
||||||
|
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||||
|
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||||
|
|
||||||
|
// for i := range c.curBatchSize {
|
||||||
|
// enabled := !slices.Contains(c.opts.Except, i)
|
||||||
|
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
|
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||||
|
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||||
|
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||||
|
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||||
|
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
|
// } else {
|
||||||
|
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
|
||||||
|
// dummyLocs = append(dummyLocs, i)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
|
||||||
|
|
||||||
|
// slog.Info("XXX Causal.StartForward", "locs", locs)
|
||||||
|
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
|
||||||
|
// c.curMask = c.buildMask(ctx)
|
||||||
|
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func newRange() cellRange {
|
||||||
|
// return cellRange{
|
||||||
|
// min: math.MaxInt,
|
||||||
|
// max: 0,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Returns a slice of locations where each token in the batch should be stored
|
||||||
|
// func (c *Causal) findLocs() ([]int32, error) {
|
||||||
|
// loc := make([]int32, 0, c.curBatchSize)
|
||||||
|
|
||||||
|
// for i := range c.cells {
|
||||||
|
// if len(c.cells[i].sequences) == 0 {
|
||||||
|
// loc = append(loc, int32(i))
|
||||||
|
// if len(loc) >= c.curBatchSize {
|
||||||
|
// return loc, nil
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) updateSlidingWindow() {
|
||||||
|
// c.curCellRange = newRange()
|
||||||
|
|
||||||
|
// if c.swaMemorySize == math.MaxInt32 {
|
||||||
|
// for _, seq := range c.curSequences {
|
||||||
|
// if seqRange, ok := c.cellRanges[seq]; ok {
|
||||||
|
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
|
||||||
|
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type lowestPosition struct {
|
||||||
|
// pos int32
|
||||||
|
// curBatch bool
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // create a map of unique sequences to the lowest position in that sequence
|
||||||
|
// lowestPos := make(map[int]lowestPosition)
|
||||||
|
// for i := range c.curPositions {
|
||||||
|
// seq := c.curSequences[i]
|
||||||
|
|
||||||
|
// lowest, ok := lowestPos[seq]
|
||||||
|
// if !ok {
|
||||||
|
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
|
||||||
|
// } else if c.curPositions[i] < lowest.pos {
|
||||||
|
// lowest.pos = c.curPositions[i]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// lowestPos[seq] = lowest
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // for any sequences are not part of this batch, clean up any tokens
|
||||||
|
// // that are no longer needed after the processing of the previous
|
||||||
|
// // batch
|
||||||
|
// for seq, seqRange := range c.cellRanges {
|
||||||
|
// if _, ok := lowestPos[seq]; !ok {
|
||||||
|
// var last int32
|
||||||
|
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||||
|
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
// last = max(last, c.cells[i].pos)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // delete any entries that are beyond the window of the oldest position in the sequence
|
||||||
|
// for seq, lowest := range lowestPos {
|
||||||
|
// oldRange, ok := c.cellRanges[seq]
|
||||||
|
// if !ok {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
|
||||||
|
// newRange := newRange()
|
||||||
|
|
||||||
|
// for i := oldRange.min; i <= oldRange.max; i++ {
|
||||||
|
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
|
||||||
|
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
// } else {
|
||||||
|
// newRange.min = min(newRange.min, i)
|
||||||
|
// newRange.max = max(newRange.max, i)
|
||||||
|
// }
|
||||||
|
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
|
||||||
|
// c.curCellRange.min = min(c.curCellRange.min, i)
|
||||||
|
// c.curCellRange.max = max(c.curCellRange.max, i)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.cellRanges[seq] = newRange
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func roundDown(length, pad int) int {
|
||||||
|
// return (length / pad) * pad
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func roundUp(length, pad int) int {
|
||||||
|
// return ((length + pad - 1) / pad) * pad
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Builds a mask of history x batch indicating whether for each token in the batch the
|
||||||
|
// // token in the history should apply. This is based on both the sequence and causality (the
|
||||||
|
// // position of the history is not ahead of the token in the batch).
|
||||||
|
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||||
|
// // Align and pad the two dimensions as required by the backend
|
||||||
|
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||||
|
|
||||||
|
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||||
|
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||||
|
|
||||||
|
// length := c.curCellRange.max - c.curCellRange.min + 1
|
||||||
|
|
||||||
|
// mask := make([]float32, batchSize*length)
|
||||||
|
|
||||||
|
// for i := range c.curBatchSize {
|
||||||
|
// enabled := !slices.Contains(c.opts.Except, i)
|
||||||
|
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
|
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||||
|
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||||
|
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||||
|
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||||
|
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||||
|
// // has already been masked out because the sequence doesn't match.
|
||||||
|
// for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||||
|
// mask[i] = float32(math.Inf(-1))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
|
||||||
|
|
||||||
|
// // if c.config.MaskDType != ml.DTypeFloat32 {
|
||||||
|
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
|
||||||
|
|
||||||
|
// return maskTensor
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) SetLayer(layer int) {
|
||||||
|
// c.curLayer = layer
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type CausalOptions struct {
|
||||||
|
// // Enabled controls whether the causal mask is generated for a particular index in a batch
|
||||||
|
// Except []int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // SetCausal disables causal mask generation for a particular range of indicies in
|
||||||
|
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
||||||
|
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||||
|
// if !slices.Equal(c.opts.Except, opts.Except) {
|
||||||
|
// c.opts = opts
|
||||||
|
// if ctx != nil {
|
||||||
|
// c.curMask = c.buildMask(ctx)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
// key := c.keys[c.curLayer]
|
||||||
|
// value := c.values[c.curLayer]
|
||||||
|
|
||||||
|
// kHeadDim := c.kHeadDims[c.curLayer]
|
||||||
|
// vHeadDim := c.vHeadDims[c.curLayer]
|
||||||
|
// numKVHeads := c.numKVHeads[c.curLayer]
|
||||||
|
// // rowSize := numKVHeads * c.curBatchSize
|
||||||
|
// // cachedSize := c.curMask.Dim(1)
|
||||||
|
// cachedSize := c.curLoc.Dim(0)
|
||||||
|
// // kCellSize := kHeadDim * numKVHeads
|
||||||
|
// // vCellSize := vHeadDim * numKVHeads
|
||||||
|
|
||||||
|
// slog.Info("XXX Causal.Get full cache", "key", key)
|
||||||
|
// slog.Info("XXX Causal.Get full cache", "value", value)
|
||||||
|
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
|
||||||
|
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
|
||||||
|
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
|
||||||
|
// // panic("XXX")
|
||||||
|
|
||||||
|
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||||
|
// // panic("full cache value")
|
||||||
|
|
||||||
|
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||||
|
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||||
|
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||||
|
|
||||||
|
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
|
||||||
|
// // panic("XXX")
|
||||||
|
|
||||||
|
// // if c.config.PermutedV {
|
||||||
|
// // panic("permuted")
|
||||||
|
// // // TODO not converted
|
||||||
|
// // vHeadDim := value.Dim(1)
|
||||||
|
// // elemSize := value.Stride(2)
|
||||||
|
|
||||||
|
// // value = value.AsStrided(ctx,
|
||||||
|
// // []int{numKVHeads, vHeadDim, cachedSize},
|
||||||
|
// // []int{value.Stride(0), value.Stride(1)},
|
||||||
|
// // elemSize*c.curCellRange.min,
|
||||||
|
// // )
|
||||||
|
// // } else {
|
||||||
|
// // vHeadDim := c.vHeadDims[c.curLayer]
|
||||||
|
// // rowSize := value.Stride(2)
|
||||||
|
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
|
||||||
|
// // panic("XXX")
|
||||||
|
|
||||||
|
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||||
|
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||||
|
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||||
|
|
||||||
|
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
|
||||||
|
// // panic("XXX")
|
||||||
|
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
|
||||||
|
// // // the 1 becomes trailing and messes up later operations
|
||||||
|
// // // This isn't the right solution, but works around it...
|
||||||
|
// // if c.curMask.Dim(1) == 1 {
|
||||||
|
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
|
||||||
|
// // }
|
||||||
|
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||||
|
// // fmt.Fprintln(os.Stderr, value.ToString())
|
||||||
|
// // panic("XXX")
|
||||||
|
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
|
||||||
|
|
||||||
|
// return key, value, c.curMask
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
// kHeadDim := key.Dim(3)
|
||||||
|
// vHeadDim := value.Dim(3)
|
||||||
|
// numKVHeads := key.Dim(1)
|
||||||
|
// batchSize := key.Dim(2)
|
||||||
|
// kCellSize := kHeadDim * numKVHeads
|
||||||
|
// vCellSize := vHeadDim * numKVHeads
|
||||||
|
|
||||||
|
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
|
||||||
|
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
|
||||||
|
// // panic("XXX")
|
||||||
|
|
||||||
|
// if c.curBatchSize != batchSize {
|
||||||
|
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
|
||||||
|
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||||
|
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||||
|
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
|
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
|
||||||
|
|
||||||
|
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
|
||||||
|
// c.kHeadDims[c.curLayer] = kHeadDim
|
||||||
|
// c.vHeadDims[c.curLayer] = vHeadDim
|
||||||
|
// c.numKVHeads[c.curLayer] = numKVHeads
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if _, ok := c.values[c.curLayer]; !ok {
|
||||||
|
// // if c.config.PermutedV {
|
||||||
|
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
|
||||||
|
// // } else {
|
||||||
|
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
|
||||||
|
// // }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||||
|
|
||||||
|
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
|
||||||
|
// // panic("XXX")
|
||||||
|
// // curLoc := 0 // TODO c.curLoc is now a tensor
|
||||||
|
// // kSize := numKVHeads * kHeadDim
|
||||||
|
// // vSize := numKVHeads * vHeadDim
|
||||||
|
// // start := []int{int(curLoc), 0}
|
||||||
|
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
|
||||||
|
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
|
||||||
|
// // strides := []int{1, 1}
|
||||||
|
|
||||||
|
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
|
||||||
|
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
|
||||||
|
|
||||||
|
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
|
||||||
|
|
||||||
|
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
|
||||||
|
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
|
||||||
|
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
|
||||||
|
// // panic("input value")
|
||||||
|
|
||||||
|
// // fmt.Fprintln(os.Stderr, t.ToString())
|
||||||
|
// // panic("XXX")
|
||||||
|
|
||||||
|
// // if c.config.PermutedV {
|
||||||
|
// // panic("permuted")
|
||||||
|
// // // TODO not adjusted
|
||||||
|
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
|
||||||
|
// // value = value.Transpose(ctx, 2, 0, 1, 3)
|
||||||
|
|
||||||
|
// // valueCache := c.values[c.curLayer]
|
||||||
|
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
|
||||||
|
|
||||||
|
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
|
||||||
|
// // } else {
|
||||||
|
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||||
|
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
|
||||||
|
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
|
||||||
|
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
|
||||||
|
|
||||||
|
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
|
||||||
|
// // }
|
||||||
|
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
|
||||||
|
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
|
||||||
|
// // panic("XXX")
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
// seqRange := newRange()
|
||||||
|
|
||||||
|
// for i := range c.cells {
|
||||||
|
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
|
||||||
|
// if slices.Contains(c.cells[i].sequences, dstSeq) {
|
||||||
|
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
|
||||||
|
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
|
||||||
|
// if i < seqRange.min {
|
||||||
|
// seqRange.min = i
|
||||||
|
// }
|
||||||
|
// if i > seqRange.max {
|
||||||
|
// seqRange.max = i
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.cellRanges[dstSeq] = seqRange
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||||
|
// if c.swaMemorySize == math.MaxInt32 {
|
||||||
|
// return true
|
||||||
|
// }
|
||||||
|
|
||||||
|
// seqRange, ok := c.cellRanges[seq]
|
||||||
|
// if !ok {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // for sliding window, check that the window of the new sequence is contained in
|
||||||
|
// // the window of what we are storing
|
||||||
|
// var first int32 = math.MaxInt32
|
||||||
|
// var last int32 = -1
|
||||||
|
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||||
|
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
// first = min(first, c.cells[i].pos)
|
||||||
|
// last = max(last, c.cells[i].pos)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if last == -1 {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
|
||||||
|
// posWindowStart := max(0, pos-c.swaWindowSize)
|
||||||
|
// return posWindowStart >= first && pos <= last+1
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||||
|
// if c.shiftFn == nil {
|
||||||
|
// return ErrNotSupported
|
||||||
|
// }
|
||||||
|
|
||||||
|
// seqRange := c.cellRanges[seq]
|
||||||
|
|
||||||
|
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
|
||||||
|
// size := min(seqRange.max-start+1, c.maxBatch)
|
||||||
|
// offsets := make([]int32, size)
|
||||||
|
|
||||||
|
// var batchFirst, batchLast int
|
||||||
|
|
||||||
|
// batchFirst = -1
|
||||||
|
// for i := range offsets {
|
||||||
|
// cell := c.cells[start+i]
|
||||||
|
|
||||||
|
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||||
|
// offsets[i] = offset
|
||||||
|
// if batchFirst < 0 {
|
||||||
|
// batchFirst = i
|
||||||
|
// }
|
||||||
|
// batchLast = i
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if batchFirst < 0 {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
|
||||||
|
// offsets = offsets[batchFirst : batchLast+1]
|
||||||
|
|
||||||
|
// slog.Info("XXX Causal.shift creating new temporary context")
|
||||||
|
// ctx := c.backend.NewContext()
|
||||||
|
// kShift := ctx.Input().FromInts(offsets, len(offsets))
|
||||||
|
|
||||||
|
// for i, key := range c.keys {
|
||||||
|
// if key == nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
|
||||||
|
// kHeadDim := key.Dim(2)
|
||||||
|
// numKVHeads := key.Dim(1)
|
||||||
|
// rowSize := key.Stride(0)
|
||||||
|
|
||||||
|
// key = key.AsStrided(ctx,
|
||||||
|
// []int{len(offsets), numKVHeads, kHeadDim},
|
||||||
|
// []int{key.Stride(0), key.Stride(1)},
|
||||||
|
// rowSize*(start+batchFirst),
|
||||||
|
// )
|
||||||
|
|
||||||
|
// roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||||
|
// if err != nil {
|
||||||
|
// ctx.Close()
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// ctx.Forward(roped.Copy(ctx, key))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// ctx.Compute()
|
||||||
|
// ctx.Close()
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||||
|
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||||
|
// // should return an error, which will trigger the runner to evaluate the full history and
|
||||||
|
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||||
|
// // results in use after free, so we don't do it for now.
|
||||||
|
|
||||||
|
// var offset int32
|
||||||
|
// if endIndex != math.MaxInt32 {
|
||||||
|
// offset = beginIndex - endIndex
|
||||||
|
// }
|
||||||
|
|
||||||
|
// seqRange := newRange()
|
||||||
|
|
||||||
|
// for i := range c.cells {
|
||||||
|
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
|
||||||
|
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
// } else {
|
||||||
|
// if c.cells[i].pos >= endIndex {
|
||||||
|
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||||
|
// return errors.New("shifting cells shared by multiple sequences not supported")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.cells[i].pos += offset
|
||||||
|
// }
|
||||||
|
// if i < seqRange.min {
|
||||||
|
// seqRange.min = i
|
||||||
|
// }
|
||||||
|
// if i > seqRange.max {
|
||||||
|
// seqRange.max = i
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if seqRange == newRange() {
|
||||||
|
// delete(c.cellRanges, seq)
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.cellRanges[seq] = seqRange
|
||||||
|
|
||||||
|
// if endIndex != math.MaxInt32 {
|
||||||
|
// err := c.shift(seq, endIndex+offset, offset)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
x/kvcache/causal_test.go (new file, 973 additions)
@@ -0,0 +1,973 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
// import (
|
||||||
|
// "fmt"
|
||||||
|
// "math"
|
||||||
|
// "slices"
|
||||||
|
// "testing"
|
||||||
|
|
||||||
|
// "github.com/ollama/ollama/ml"
|
||||||
|
// "github.com/ollama/ollama/model/input"
|
||||||
|
// )
|
||||||
|
|
||||||
|
// type testCase struct {
|
||||||
|
// name string
|
||||||
|
// in []float32
|
||||||
|
// inShape []int
|
||||||
|
// seqs []int
|
||||||
|
// pos []int32
|
||||||
|
// expected []float32
|
||||||
|
// expectedShape []int
|
||||||
|
// expectedMask []float32
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
|
||||||
|
// t.Helper()
|
||||||
|
// for _, permuted := range []bool{false, true} {
|
||||||
|
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
|
||||||
|
// fn(t, &testBackend{permutedV: permuted})
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestStore(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewCausalCache(nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// tests := []testCase{
|
||||||
|
// {
|
||||||
|
// name: "FirstBatch",
|
||||||
|
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||||
|
// inShape: []int{2, 3, 4},
|
||||||
|
// seqs: []int{0, 0, 0, 0},
|
||||||
|
// pos: []int32{0, 1, 2, 3},
|
||||||
|
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||||
|
// expectedShape: []int{2, 3, 4},
|
||||||
|
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "SecondBatch",
|
||||||
|
// in: []float32{115, 215, 125, 225, 135, 235},
|
||||||
|
// inShape: []int{2, 3, 1},
|
||||||
|
// seqs: []int{0},
|
||||||
|
// pos: []int32{4},
|
||||||
|
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||||
|
// expectedShape: []int{2, 3, 5},
|
||||||
|
// expectedMask: []float32{0, 0, 0, 0, 0},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestSWA(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewSWACache(1, nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// x := float32(math.Inf(-1))
|
||||||
|
|
||||||
|
// tests := []testCase{
|
||||||
|
// {
|
||||||
|
// name: "FirstBatch",
|
||||||
|
// in: []float32{1, 2, 3, 4},
|
||||||
|
// inShape: []int{1, 1, 4},
|
||||||
|
// seqs: []int{0, 0, 0, 0},
|
||||||
|
// pos: []int32{0, 1, 2, 3},
|
||||||
|
// expected: []float32{1, 2, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, x,
|
||||||
|
// 0, 0, x, x,
|
||||||
|
// x, 0, 0, x,
|
||||||
|
// x, x, 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "SecondBatch",
|
||||||
|
// in: []float32{5, 6},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 0},
|
||||||
|
// pos: []int32{4, 5},
|
||||||
|
// expected: []float32{5, 6, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, 0,
|
||||||
|
// 0, 0, x, x,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestSWASeparateBatches(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewSWACache(1, nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
||||||
|
|
||||||
|
// x := float32(math.Inf(-1))
|
||||||
|
|
||||||
|
// tests := []testCase{
|
||||||
|
// {
|
||||||
|
// name: "First seq 0",
|
||||||
|
// in: []float32{1, 2},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 0},
|
||||||
|
// pos: []int32{0, 1},
|
||||||
|
// expected: []float32{1, 2},
|
||||||
|
// expectedShape: []int{1, 1, 2},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x,
|
||||||
|
// 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "Second seq 0",
|
||||||
|
// in: []float32{3, 4},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 0},
|
||||||
|
// pos: []int32{2, 3},
|
||||||
|
// expected: []float32{2, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 3},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, 0, x,
|
||||||
|
// x, 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "First seq 1",
|
||||||
|
// in: []float32{5, 6},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{1, 1},
|
||||||
|
// pos: []int32{0, 1},
|
||||||
|
// expected: []float32{5, 6},
|
||||||
|
// expectedShape: []int{1, 1, 2},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x,
|
||||||
|
// 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "Second seq 1",
|
||||||
|
// in: []float32{7, 8},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{1, 1},
|
||||||
|
// pos: []int32{2, 3},
|
||||||
|
// expected: []float32{6, 3, 4, 7, 8},
|
||||||
|
// expectedShape: []int{1, 1, 5},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, 0, x,
|
||||||
|
// x, x, x, 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "Third seq 0",
|
||||||
|
// in: []float32{9, 10},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 0},
|
||||||
|
// pos: []int32{4, 5},
|
||||||
|
// expected: []float32{9, 10, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, 0,
|
||||||
|
// 0, 0, x, x,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestSWAMem(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewSWAMemCache(1, 3, nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// x := float32(math.Inf(-1))
|
||||||
|
|
||||||
|
// tests := []testCase{
|
||||||
|
// {
|
||||||
|
// name: "FirstBatch",
|
||||||
|
// in: []float32{1, 2, 3, 4},
|
||||||
|
// inShape: []int{1, 1, 4},
|
||||||
|
// seqs: []int{0, 0, 0, 0},
|
||||||
|
// pos: []int32{0, 1, 2, 3},
|
||||||
|
// expected: []float32{1, 2, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, x,
|
||||||
|
// 0, 0, x, x,
|
||||||
|
// x, 0, 0, x,
|
||||||
|
// x, x, 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "SecondBatch",
|
||||||
|
// in: []float32{5, 6},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 0},
|
||||||
|
// pos: []int32{4, 5},
|
||||||
|
// expected: []float32{5, 2, 3, 4, 6},
|
||||||
|
// expectedShape: []int{1, 1, 5},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, 0, x,
|
||||||
|
// 0, x, x, x, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestChunkedAttention(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewChunkedAttentionCache(2, nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// x := float32(math.Inf(-1))
|
||||||
|
|
||||||
|
// testCache(
|
||||||
|
// t, backend, cache,
|
||||||
|
// []testCase{
|
||||||
|
// {
|
||||||
|
// name: "FirstBatch",
|
||||||
|
// in: []float32{1, 2, 3, 4},
|
||||||
|
// inShape: []int{1, 1, 4},
|
||||||
|
// seqs: []int{0, 0, 0, 0},
|
||||||
|
// pos: []int32{0, 1, 2, 3},
|
||||||
|
// expected: []float32{1, 2, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, x,
|
||||||
|
// 0, 0, x, x,
|
||||||
|
// x, x, 0, x,
|
||||||
|
// x, x, 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "SecondBatch",
|
||||||
|
// in: []float32{5, 6, 7},
|
||||||
|
// inShape: []int{1, 1, 3},
|
||||||
|
// seqs: []int{0, 0, 0},
|
||||||
|
// pos: []int32{4, 5, 6},
|
||||||
|
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
||||||
|
// expectedShape: []int{1, 1, 7},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// x, x, x, x, 0, x, x,
|
||||||
|
// x, x, x, x, 0, 0, x,
|
||||||
|
// x, x, x, x, x, x, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "ThirdBatch",
|
||||||
|
// in: []float32{8, 9},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 0},
|
||||||
|
// pos: []int32{7, 8},
|
||||||
|
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||||
|
// expectedShape: []int{1, 1, 9},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// x, x, x, x, x, x, 0, 0, x,
|
||||||
|
// x, x, x, x, x, x, x, x, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// )
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestSequences(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewCausalCache(nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// tests := []testCase{
|
||||||
|
// {
|
||||||
|
// name: "FirstBatch",
|
||||||
|
// in: []float32{1, 2, 3, 4},
|
||||||
|
// inShape: []int{1, 1, 4},
|
||||||
|
// seqs: []int{0, 0, 1, 1},
|
||||||
|
// pos: []int32{0, 1, 0, 1},
|
||||||
|
// expected: []float32{1, 2, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "SecondBatch",
|
||||||
|
// in: []float32{5, 6},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 1},
|
||||||
|
// pos: []int32{2, 2},
|
||||||
|
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
// expectedShape: []int{1, 1, 6},
|
||||||
|
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestRemove(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
// return key.Add(ctx, shift), nil
|
||||||
|
// })
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// x := float32(math.Inf(-1))
|
||||||
|
|
||||||
|
// tests := []testCase{
|
||||||
|
// {
|
||||||
|
// name: "FirstBatch",
|
||||||
|
// in: []float32{1, 2, 3, 4},
|
||||||
|
// inShape: []int{1, 1, 4},
|
||||||
|
// seqs: []int{0, 0, 1, 1},
|
||||||
|
// pos: []int32{0, 1, 0, 1},
|
||||||
|
// expected: []float32{1, 2, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, x, x, x,
|
||||||
|
// 0, 0, x, x,
|
||||||
|
// x, x, 0, x,
|
||||||
|
// x, x, 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
// err := cache.Remove(0, 1, math.MaxInt32)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// tests = []testCase{
|
||||||
|
// {
|
||||||
|
// name: "RemoveEnd",
|
||||||
|
// in: []float32{5, 6},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 1},
|
||||||
|
// pos: []int32{1, 2},
|
||||||
|
// expected: []float32{1, 5, 3, 4, 6},
|
||||||
|
// expectedShape: []int{1, 1, 5},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, 0, x, x, x,
|
||||||
|
// x, x, 0, 0, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
// err = cache.Remove(0, 0, 1)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// tests = []testCase{
|
||||||
|
// {
|
||||||
|
// name: "RemoveMiddle",
|
||||||
|
// in: []float32{7, 8},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{0, 0},
|
||||||
|
// pos: []int32{1, 2},
|
||||||
|
// expected: []float32{7, 4, 3, 4, 6, 8},
|
||||||
|
// expectedShape: []int{1, 1, 6},
|
||||||
|
// expectedMask: []float32{
|
||||||
|
// 0, 0, x, x, x, x,
|
||||||
|
// 0, 0, x, x, x, 0,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestCopy(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// tests := []testCase{
|
||||||
|
// {
|
||||||
|
// name: "FirstBatch",
|
||||||
|
// in: []float32{1, 2, 3, 4},
|
||||||
|
// inShape: []int{1, 1, 4},
|
||||||
|
// seqs: []int{0, 0, 0, 0},
|
||||||
|
// pos: []int32{0, 1, 2, 3},
|
||||||
|
// expected: []float32{1, 2, 3, 4},
|
||||||
|
// expectedShape: []int{1, 1, 4},
|
||||||
|
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
// cache.CopyPrefix(0, 1, 2)
|
||||||
|
|
||||||
|
// tests = []testCase{
|
||||||
|
// {
|
||||||
|
// name: "Copy",
|
||||||
|
// in: []float32{5, 6},
|
||||||
|
// inShape: []int{1, 1, 2},
|
||||||
|
// seqs: []int{1, 1},
|
||||||
|
// pos: []int32{3, 4},
|
||||||
|
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
// expectedShape: []int{1, 1, 6},
|
||||||
|
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
// testCache(t, backend, cache, tests)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||||
|
// for _, test := range tests {
|
||||||
|
// t.Run(test.name, func(t *testing.T) {
|
||||||
|
// context := backend.NewContext()
|
||||||
|
// defer context.Close()
|
||||||
|
|
||||||
|
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// cache.SetLayer(0)
|
||||||
|
// tensor := context.FromFloats(test.in, test.inShape...)
|
||||||
|
// cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
// out, _, mask := cache.Get(context)
|
||||||
|
|
||||||
|
// context.Forward(out, mask).Compute(out, mask)
|
||||||
|
|
||||||
|
// if !slices.Equal(out.Floats(), test.expected) {
|
||||||
|
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if !slices.Equal(out.Shape(), test.expectedShape) {
|
||||||
|
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||||
|
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestCanResume(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// windowSize := int32(4)
|
||||||
|
// cache := NewSWACache(windowSize, nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// context := backend.NewContext()
|
||||||
|
// defer context.Close()
|
||||||
|
|
||||||
|
// err := cache.StartForward(context, input.Batch{
|
||||||
|
// Positions: []int32{0, 1, 2, 3, 4},
|
||||||
|
// Sequences: []int{0, 0, 0, 0, 0},
|
||||||
|
// }, false)
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatalf("StartForward failed: %v", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// cache.SetLayer(0)
|
||||||
|
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
||||||
|
// cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
// // with window size 4, nothing has slid out of the window yet
|
||||||
|
// if !cache.CanResume(0, 0) {
|
||||||
|
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||||
|
// }
|
||||||
|
// if !cache.CanResume(0, 1) {
|
||||||
|
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||||
|
// }
|
||||||
|
// if !cache.CanResume(0, 2) {
|
||||||
|
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||||
|
// }
|
||||||
|
// if !cache.CanResume(0, 3) {
|
||||||
|
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||||
|
// }
|
||||||
|
// if !cache.CanResume(0, 4) {
|
||||||
|
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // shift window by adding position 5
|
||||||
|
// err = cache.StartForward(context, input.Batch{
|
||||||
|
// Positions: []int32{5},
|
||||||
|
// Sequences: []int{0},
|
||||||
|
// }, false)
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatalf("StartForward failed: %v", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// cache.SetLayer(0)
|
||||||
|
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
|
||||||
|
// cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
// // only the latest position has overlapping windows
|
||||||
|
// if cache.CanResume(0, 0) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 1) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 2) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 3) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 4) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if !cache.CanResume(0, 5) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func TestCanResumeSWAMem(t *testing.T) {
|
||||||
|
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
|
// windowSize := int32(4)
|
||||||
|
// memSize := int32(5)
|
||||||
|
// cache := NewSWAMemCache(windowSize, memSize, nil)
|
||||||
|
// defer cache.Close()
|
||||||
|
|
||||||
|
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
// context := backend.NewContext()
|
||||||
|
// defer context.Close()
|
||||||
|
|
||||||
|
// err := cache.StartForward(context, input.Batch{
|
||||||
|
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
||||||
|
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
||||||
|
// }, false)
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatalf("StartForward failed: %v", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// cache.SetLayer(0)
|
||||||
|
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
||||||
|
// cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
// // shift window by adding position 7
|
||||||
|
// err = cache.StartForward(context, input.Batch{
|
||||||
|
// Positions: []int32{7},
|
||||||
|
// Sequences: []int{0},
|
||||||
|
// }, false)
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatalf("StartForward failed: %v", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// cache.SetLayer(0)
|
||||||
|
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
|
||||||
|
// cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
// // only the latest position has overlapping windows
|
||||||
|
// if cache.CanResume(0, 0) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 1) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 2) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 3) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 4) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if cache.CanResume(0, 5) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
|
||||||
|
// }
|
||||||
|
// if !cache.CanResume(0, 6) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
|
||||||
|
// }
|
||||||
|
// if !cache.CanResume(0, 7) {
|
||||||
|
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type testBackend struct {
|
||||||
|
// ml.Backend
|
||||||
|
// permutedV bool
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (b *testBackend) NewContext() ml.Context {
|
||||||
|
// return &testContext{}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (b *testBackend) NewContextSize(int) ml.Context {
|
||||||
|
// return &testContext{}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (b *testBackend) CacheConfig() ml.CacheConfig {
|
||||||
|
// return ml.CacheConfig{PermutedV: b.permutedV}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type testContext struct {
|
||||||
|
// ml.Context
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
// total := 0
|
||||||
|
|
||||||
|
// if len(shape) > 0 {
|
||||||
|
// total = 1
|
||||||
|
// for _, s := range shape {
|
||||||
|
// total *= s
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
// return c.Empty(dtype, shape...)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
|
||||||
|
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||||
|
|
||||||
|
// copy(t.data, s)
|
||||||
|
|
||||||
|
// return t
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
|
||||||
|
// f := make([]float32, len(s))
|
||||||
|
// for i := range f {
|
||||||
|
// f[i] = float32(s[i])
|
||||||
|
// }
|
||||||
|
|
||||||
|
// out := c.FromFloats(f, shape...)
|
||||||
|
// out.(*testTensor).dtype = ml.DTypeI32
|
||||||
|
|
||||||
|
// return out
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||||
|
// s := make([]float32, 0, int((stop-start)/step))
|
||||||
|
// for i := start; i < stop; i += step {
|
||||||
|
// s = append(s, i)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// out := c.FromFloats(s, len(s))
|
||||||
|
// out.(*testTensor).dtype = dtype
|
||||||
|
// return out
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *testContext) Input() ml.Context { return c }
|
||||||
|
// func (c *testContext) Layer(int) ml.Context { return c }
|
||||||
|
|
||||||
|
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||||
|
|
||||||
|
// func (c *testContext) Compute(...ml.Tensor) {}
|
||||||
|
|
||||||
|
// func (c *testContext) Reserve() {}
|
||||||
|
|
||||||
|
// func (c *testContext) MaxGraphNodes() int {
|
||||||
|
// return 10
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *testContext) Close() {}
|
||||||
|
|
||||||
|
// type testTensor struct {
|
||||||
|
// ml.Tensor
|
||||||
|
|
||||||
|
// dtype ml.DType
|
||||||
|
// elementSize int
|
||||||
|
// data []float32
|
||||||
|
// shape []int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Dim(n int) int {
|
||||||
|
// return t.shape[n]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Stride(n int) int {
|
||||||
|
// stride := t.elementSize
|
||||||
|
// for i := range n {
|
||||||
|
// stride *= t.shape[i]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return stride
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Shape() []int {
|
||||||
|
// return t.shape
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) DType() ml.DType {
|
||||||
|
// return t.dtype
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Floats() []float32 {
|
||||||
|
// out := make([]float32, len(t.data))
|
||||||
|
// copy(out, t.data)
|
||||||
|
// return out
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
|
||||||
|
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||||
|
// for i := range out.data {
|
||||||
|
// out.data[i] = -t.data[i]
|
||||||
|
// }
|
||||||
|
// return out
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
|
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||||
|
|
||||||
|
// for i := range out.data {
|
||||||
|
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return out
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||||
|
// return &testTensor{
|
||||||
|
// dtype: t.dtype,
|
||||||
|
// elementSize: t.elementSize,
|
||||||
|
// data: t.data,
|
||||||
|
// shape: shape,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||||
|
// offset /= t.elementSize
|
||||||
|
|
||||||
|
// var s []int
|
||||||
|
|
||||||
|
// switch len(shape) {
|
||||||
|
// case 1:
|
||||||
|
// s = []int{shape[0]}
|
||||||
|
// case 3:
|
||||||
|
// s = []int{shape[0], shape[2]}
|
||||||
|
// case 5:
|
||||||
|
// s = []int{shape[0], shape[2], shape[4]}
|
||||||
|
// default:
|
||||||
|
// panic("unsupported number of dimensions")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// context := &testContext{}
|
||||||
|
|
||||||
|
// view := context.Empty(t.dtype, s...).(*testTensor)
|
||||||
|
// view.data = t.data[offset : offset+len(view.data)]
|
||||||
|
|
||||||
|
// return view
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
|
||||||
|
// if len(t.shape) > 4 || len(order) > 4 {
|
||||||
|
// panic("permute only supports up to 4 dimensions")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if len(order) != len(t.shape) && len(order) != 4 {
|
||||||
|
// panic("invalid number of dimensions for permute")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
|
||||||
|
// orderFull := append(make([]int, 0, 4), order...)
|
||||||
|
// for len(orderFull) < 4 {
|
||||||
|
// orderFull = append(orderFull, len(orderFull))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// seen := [4]bool{}
|
||||||
|
|
||||||
|
// shape4 := [4]int{1, 1, 1, 1}
|
||||||
|
// for i := 0; i < len(t.shape) && i < 4; i++ {
|
||||||
|
// shape4[i] = t.shape[i]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// newShape4 := [4]int{1, 1, 1, 1}
|
||||||
|
// for axis := range 4 {
|
||||||
|
// dst := orderFull[axis]
|
||||||
|
// if dst < 0 || dst >= 4 {
|
||||||
|
// panic("invalid axis for permute")
|
||||||
|
// }
|
||||||
|
// if seen[dst] {
|
||||||
|
// panic("duplicate axis for permute")
|
||||||
|
// }
|
||||||
|
// seen[dst] = true
|
||||||
|
// newShape4[dst] = shape4[axis]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// total := len(t.data)
|
||||||
|
// newData := make([]float32, total)
|
||||||
|
|
||||||
|
// if total > 0 {
|
||||||
|
// oldDims := shape4
|
||||||
|
// newDims := newShape4
|
||||||
|
|
||||||
|
// oldStride := [4]int{1, 1, 1, 1}
|
||||||
|
// newStride := [4]int{1, 1, 1, 1}
|
||||||
|
// for i := 1; i < 4; i++ {
|
||||||
|
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
|
||||||
|
// newStride[i] = newStride[i-1] * newDims[i-1]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// var coords [4]int
|
||||||
|
// var newCoords [4]int
|
||||||
|
|
||||||
|
// for idx := range total {
|
||||||
|
// remainder := idx
|
||||||
|
// for axis := range 4 {
|
||||||
|
// dim := oldDims[axis]
|
||||||
|
// if dim == 0 {
|
||||||
|
// coords[axis] = 0
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// coords[axis] = remainder % dim
|
||||||
|
// remainder /= dim
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for axis := range 4 {
|
||||||
|
// newCoords[orderFull[axis]] = coords[axis]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// newIndex := 0
|
||||||
|
// for axis := range 4 {
|
||||||
|
// if newDims[axis] == 0 {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// newIndex += newCoords[axis] * newStride[axis]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// newData[newIndex] = t.data[idx]
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// numDims := 4
|
||||||
|
// for numDims > 1 && newShape4[numDims-1] <= 1 {
|
||||||
|
// numDims--
|
||||||
|
// }
|
||||||
|
|
||||||
|
// newShape := make([]int, numDims)
|
||||||
|
// copy(newShape, newShape4[:numDims])
|
||||||
|
|
||||||
|
// return &testTensor{
|
||||||
|
// dtype: t.dtype,
|
||||||
|
// elementSize: t.elementSize,
|
||||||
|
// data: newData,
|
||||||
|
// shape: newShape,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
|
||||||
|
// dst := t
|
||||||
|
// srcTensor := src.(*testTensor)
|
||||||
|
// idxTensor := idxs.(*testTensor)
|
||||||
|
|
||||||
|
// shapeTo4D := func(shape []int) [4]int {
|
||||||
|
// out := [4]int{1, 1, 1, 1}
|
||||||
|
// for i := 0; i < len(shape) && i < 4; i++ {
|
||||||
|
// out[i] = shape[i]
|
||||||
|
// }
|
||||||
|
// return out
|
||||||
|
// }
|
||||||
|
|
||||||
|
// computeStrides := func(shape [4]int) [4]int {
|
||||||
|
// out := [4]int{1, 1, 1, 1}
|
||||||
|
// for i := 1; i < 4; i++ {
|
||||||
|
// out[i] = out[i-1] * shape[i-1]
|
||||||
|
// }
|
||||||
|
// return out
|
||||||
|
// }
|
||||||
|
|
||||||
|
// dstShape4D := shapeTo4D(dst.shape)
|
||||||
|
// srcShape4D := shapeTo4D(srcTensor.shape)
|
||||||
|
// idxShape4D := shapeTo4D(idxTensor.shape)
|
||||||
|
|
||||||
|
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
|
||||||
|
// panic("SetRows requires matching tensor shapes")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if srcShape4D[1] != idxShape4D[0] {
|
||||||
|
// panic("SetRows rows/index mismatch")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
|
||||||
|
// panic("SetRows cannot broadcast indices")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if idxShape4D[3] != 1 {
|
||||||
|
// panic("SetRows expects 1D or 2D index tensors")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// dstStride := computeStrides(dstShape4D)
|
||||||
|
// srcStride := computeStrides(srcShape4D)
|
||||||
|
// idxStride := computeStrides(idxShape4D)
|
||||||
|
|
||||||
|
// numColumns := srcShape4D[0]
|
||||||
|
// numRows := srcShape4D[1]
|
||||||
|
|
||||||
|
// for dim3Index := range dstShape4D[3] {
|
||||||
|
// for dim2Index := range dstShape4D[2] {
|
||||||
|
// idxDim2 := 0
|
||||||
|
// idxDim3 := 0
|
||||||
|
// if idxShape4D[1] > 0 {
|
||||||
|
// idxDim2 = dim2Index % idxShape4D[1]
|
||||||
|
// }
|
||||||
|
// if idxShape4D[2] > 0 {
|
||||||
|
// idxDim3 = dim3Index % idxShape4D[2]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
|
||||||
|
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
|
||||||
|
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
|
||||||
|
|
||||||
|
// for row := range numRows {
|
||||||
|
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
|
||||||
|
// if idx < 0 || idx >= dstShape4D[1] {
|
||||||
|
// panic("SetRows index out of range")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// srcOffset := srcBase + row*srcStride[1]
|
||||||
|
// dstOffset := dstBase + idx*dstStride[1]
|
||||||
|
|
||||||
|
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return dst
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
|
// copy(t2.(*testTensor).data, t.data)
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
x/kvcache/encoder.go (new file, 156 additions)
@@ -0,0 +1,156 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
// import (
|
||||||
|
// "fmt"
|
||||||
|
|
||||||
|
// "github.com/ollama/ollama/ml"
|
||||||
|
// "github.com/ollama/ollama/model/input"
|
||||||
|
// )
|
||||||
|
|
||||||
|
// // Encoder cache stores K and V tensors that are position independent
|
||||||
|
// //
|
||||||
|
// // The tensors can be of any shape and will be returned as they were stored
|
||||||
|
// // The mask is currently always nil
|
||||||
|
// //
|
||||||
|
// // Not currently safe for multiple sequences
|
||||||
|
// type EncoderCache struct {
|
||||||
|
// // config controls mostly backend-specific optimizations
|
||||||
|
// config *ml.CacheConfig
|
||||||
|
|
||||||
|
// // ** current forward pass **
|
||||||
|
|
||||||
|
// // the active layer for Get and Put
|
||||||
|
// curLayer int
|
||||||
|
|
||||||
|
// // if something is stored during this pass, this
|
||||||
|
// // will be the position (but there is no guarantee
|
||||||
|
// // anything will be stored)
|
||||||
|
// curPos int32
|
||||||
|
|
||||||
|
// // curReserve indicates that this forward pass is only for
|
||||||
|
// // memory reservation and we should not update our metadata
|
||||||
|
// // based on it.
|
||||||
|
// curReserve bool
|
||||||
|
|
||||||
|
// // ** cache metadata **
|
||||||
|
|
||||||
|
// // was something stored in the cache?
|
||||||
|
// encoderCached bool
|
||||||
|
|
||||||
|
// // position of the cached data
|
||||||
|
// encoderPos int32
|
||||||
|
|
||||||
|
// // ** cache data storage **
|
||||||
|
// backend ml.Backend
|
||||||
|
// ctxs map[int]ml.Context
|
||||||
|
// keys, values map[int]ml.Tensor
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func NewEncoderCache() *EncoderCache {
|
||||||
|
// return &EncoderCache{
|
||||||
|
// ctxs: make(map[int]ml.Context),
|
||||||
|
// keys: make(map[int]ml.Tensor),
|
||||||
|
// values: make(map[int]ml.Tensor),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
// if c.config == nil {
|
||||||
|
// var config ml.CacheConfig
|
||||||
|
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
|
// config = cc.CacheConfig()
|
||||||
|
// }
|
||||||
|
// c.config = &config
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if maxSequences > 1 {
|
||||||
|
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||||
|
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.backend = backend
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||||
|
// if c.config != nil {
|
||||||
|
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.config = &config
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) Close() {
|
||||||
|
// for _, ctx := range c.ctxs {
|
||||||
|
// ctx.Close()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||||
|
// // We work with the most recent image
|
||||||
|
// if len(batch.Multimodal) > 0 {
|
||||||
|
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.curReserve = reserve
|
||||||
|
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) SetLayer(layer int) {
|
||||||
|
// c.curLayer = layer
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) EncoderCached() bool {
|
||||||
|
// return c.encoderCached
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
// return c.keys[c.curLayer], c.values[c.curLayer], nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
// if !c.curReserve {
|
||||||
|
// c.encoderPos = c.curPos
|
||||||
|
// c.encoderCached = true
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if c.config.PermutedV {
|
||||||
|
// value = value.Transpose(ctx, 1, 2, 0, 3)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||||
|
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
|
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if _, ok := c.values[c.curLayer]; !ok {
|
||||||
|
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// ctx.Forward(
|
||||||
|
// key.Copy(ctx, c.keys[c.curLayer]),
|
||||||
|
// value.Copy(ctx, c.values[c.curLayer]),
|
||||||
|
// )
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
// panic("encoder cache does not support multiple sequences")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||||
|
// return true
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||||
|
// c.encoderCached = false
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
x/kvcache/mlx.go (new file, 144 additions)
@@ -0,0 +1,144 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/x/ml"
|
||||||
|
"github.com/ollama/ollama/x/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MLXCausal stores K and V tensors according to their position in the
// sequence. Get returns the full cached history; the attention mask is currently nil.
|
||||||
|
type MLXCausal struct {
|
||||||
|
DType ml.DType
|
||||||
|
|
||||||
|
// locations for data storage for this batch
|
||||||
|
curLocPut ml.Tensor
|
||||||
|
|
||||||
|
// locations for data retrieval for this batch
|
||||||
|
curLocGet ml.Tensor
|
||||||
|
|
||||||
|
// the active layer for Get and Put
|
||||||
|
curLayer int
|
||||||
|
|
||||||
|
capacity int
|
||||||
|
|
||||||
|
offset int
|
||||||
|
|
||||||
|
backend ml.Backend
|
||||||
|
ctxs map[int]ml.Context
|
||||||
|
keys, values map[int]ml.Tensor
|
||||||
|
|
||||||
|
// TODO is this needed per layer, or will it always be consistent?
|
||||||
|
kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMLXCausalCache() *MLXCausal {
|
||||||
|
return &MLXCausal{
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
kHeadDims: make(map[int]int),
|
||||||
|
vHeadDims: make(map[int]int),
|
||||||
|
numKVHeads: make(map[int]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
c.DType = dtype
|
||||||
|
c.capacity = capacity
|
||||||
|
c.backend = backend
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
|
||||||
|
|
||||||
|
func (c *MLXCausal) SetLayer(layer int) {
|
||||||
|
c.curLayer = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) Close() {
|
||||||
|
// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
|
||||||
|
for _, ctx := range c.ctxs {
|
||||||
|
ctx.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||||
|
locsPut := make([]int32, len(batch.Positions))
|
||||||
|
for i := c.offset; i < c.offset+len(batch.Positions); i++ {
|
||||||
|
locsPut[i-c.offset] = int32(i)
|
||||||
|
}
|
||||||
|
c.offset += len(batch.Positions)
|
||||||
|
locsGet := make([]int32, c.offset)
|
||||||
|
for i := range c.offset {
|
||||||
|
locsGet[i] = int32(i)
|
||||||
|
}
|
||||||
|
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
|
||||||
|
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
|
||||||
|
// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
kHeadDim := key.Dim(3)
|
||||||
|
vHeadDim := value.Dim(3)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
batchSize := key.Dim(2)
|
||||||
|
kCellSize := kHeadDim * numKVHeads
|
||||||
|
vCellSize := vHeadDim * numKVHeads
|
||||||
|
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
|
||||||
|
|
||||||
|
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||||
|
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||||
|
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
|
// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
|
||||||
|
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
|
||||||
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
|
||||||
|
c.kHeadDims[c.curLayer] = kHeadDim
|
||||||
|
c.vHeadDims[c.curLayer] = vHeadDim
|
||||||
|
c.numKVHeads[c.curLayer] = numKVHeads
|
||||||
|
}
|
||||||
|
key = key.Reshape(ctx, batchSize, 1, kCellSize)
|
||||||
|
|
||||||
|
// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
|
||||||
|
// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
|
||||||
|
// slog.Info("XXX MLXCausal.Put ", "key", key)
|
||||||
|
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
|
||||||
|
value = value.Reshape(ctx, batchSize, 1, vCellSize)
|
||||||
|
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
key := c.keys[c.curLayer]
|
||||||
|
value := c.values[c.curLayer]
|
||||||
|
|
||||||
|
kHeadDim := c.kHeadDims[c.curLayer]
|
||||||
|
vHeadDim := c.vHeadDims[c.curLayer]
|
||||||
|
numKVHeads := c.numKVHeads[c.curLayer]
|
||||||
|
// rowSize := numKVHeads * c.curBatchSize
|
||||||
|
// cachedSize := c.curMask.Dim(1)
|
||||||
|
cachedSize := c.curLocGet.Dim(0)
|
||||||
|
// kCellSize := kHeadDim * numKVHeads
|
||||||
|
// vCellSize := vHeadDim * numKVHeads
|
||||||
|
// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
|
||||||
|
|
||||||
|
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||||
|
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||||
|
return key, value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) CanResume(seq int, pos int32) bool {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
panic("not implemented")
|
||||||
|
}
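
To make the intended call order concrete, here is a minimal, hypothetical in-package sketch (not part of this diff) of how a model forward pass might drive `MLXCausal`. The backend, batch, and per-layer key/value tensors are assumed to come from the surrounding model code.

```go
// forwardAllLayers is a hypothetical sketch showing the intended call order for
// MLXCausal; it assumes it lives alongside mlx.go and reuses its imports.
func forwardAllLayers(backend ml.Backend, dtype ml.DType, batch input.Batch,
	keys, values []ml.Tensor) error {
	cache := NewMLXCausalCache()
	cache.Init(backend, dtype, 1 /*maxSequences*/, 4096 /*capacity*/, 512 /*maxBatch*/)
	defer cache.Close()

	ctx := backend.NewContext()
	defer ctx.Close()

	// Record where this batch's rows are scattered (Put) and which rows are
	// gathered back (Get) for the whole history seen so far.
	if err := cache.StartForward(ctx, batch, false); err != nil {
		return err
	}

	for layer := range keys {
		cache.SetLayer(layer)
		cache.Put(ctx, keys[layer], values[layer]) // scatter this batch's K/V rows by position
		k, v, _ := cache.Get(ctx)                  // full history; mask is currently nil
		_, _ = k, v                                // attention over k and v would go here
	}
	return nil
}
```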
|
||||||
x/kvcache/wrapper.go (new file, 110 additions)
@@ -0,0 +1,110 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
// import (
|
||||||
|
// "math"
|
||||||
|
|
||||||
|
// "github.com/ollama/ollama/ml"
|
||||||
|
// "github.com/ollama/ollama/model/input"
|
||||||
|
// )
|
||||||
|
|
||||||
|
// // Wrapper cache is a container for multiple types of caches,
|
||||||
|
// // such as for the encoding and decoding portions of a model.
|
||||||
|
// type WrapperCache struct {
|
||||||
|
// // caches we are wrapping
|
||||||
|
// caches []Cache
|
||||||
|
|
||||||
|
// // cache to be used for this layer
|
||||||
|
// curType int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||||
|
// return &WrapperCache{
|
||||||
|
// caches: caches,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
// for _, cache := range c.caches {
|
||||||
|
// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||||
|
// for _, cache := range c.caches {
|
||||||
|
// cache.SetConfig(config)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) Close() {
|
||||||
|
// for _, cache := range c.caches {
|
||||||
|
// cache.Close()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||||
|
// for i, cache := range c.caches {
|
||||||
|
// err := cache.StartForward(ctx, batch, reserve)
|
||||||
|
// if err != nil {
|
||||||
|
// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||||
|
// for j := i - 1; j >= 0; j-- {
|
||||||
|
// for k := range batch.Positions {
|
||||||
|
// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.curType = 0
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) SetLayer(layer int) {
|
||||||
|
// for _, cache := range c.caches {
|
||||||
|
// cache.SetLayer(layer)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) SetLayerType(layerType int) {
|
||||||
|
// c.curType = layerType
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) UnderlyingCache() Cache {
|
||||||
|
// return c.caches[c.curType]
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
// return c.caches[c.curType].Get(ctx)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
// c.caches[c.curType].Put(ctx, key, value)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
// for _, cache := range c.caches {
|
||||||
|
// cache.CopyPrefix(srcSeq, dstSeq, len)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||||
|
// for _, cache := range c.caches {
|
||||||
|
// if !cache.CanResume(seq, pos) {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return true
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
// // If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||||
|
// for _, cache := range c.caches {
|
||||||
|
// err := cache.Remove(seq, beginIndex, endIndex)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
433
x/ml/backend.go
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
package ml
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Backend interface {
|
||||||
|
// Close frees all memory associated with this backend
|
||||||
|
// Close()
|
||||||
|
|
||||||
|
// Load(ctx context.Context, progress func(float32)) error
|
||||||
|
|
||||||
|
// BackendMemory returns the memory allocations that were made for this model
|
||||||
|
// BackendMemory() BackendMemory
|
||||||
|
|
||||||
|
Config() fs.Config
|
||||||
|
Get(name string) Tensor
|
||||||
|
NewContext() Context
|
||||||
|
// NewContextSize(size int) Context
|
||||||
|
|
||||||
|
// Enumerate the devices available for inference via this backend
|
||||||
|
// BackendDevices() []DeviceInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackendCacheConfig should be implemented by backends that need special output
|
||||||
|
// from the cache to meet specific requirements. It is frequently implemented in
|
||||||
|
// conjunction with ScaledDotProductAttention.
|
||||||
|
type BackendCacheConfig interface {
|
||||||
|
CacheConfig() CacheConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||||
|
// the output of the cache to work better with specific kernels.
|
||||||
|
type CacheConfig struct {
|
||||||
|
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||||
|
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||||
|
// cache itself will also be increased to a multiple of this size if needed.
|
||||||
|
CachePadding int
|
||||||
|
|
||||||
|
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||||
|
// and return the permuted version via Get. This uses the cache copy operation
|
||||||
|
// to avoid a Contiguous call on the permuted tensor.
|
||||||
|
PermutedV bool
|
||||||
|
|
||||||
|
// MaskDType specifies the data type for generating the mask. If unset it will
|
||||||
|
// default to DTypeF32.
|
||||||
|
MaskDType DType
|
||||||
|
|
||||||
|
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
||||||
|
// Any position that does not correspond to an actual token will be filled with -Inf.
|
||||||
|
MaskBatchPadding int
|
||||||
|
}
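
As a rough illustration of how these fields are meant to be consumed, a backend that implements `BackendCacheConfig` might advertise something like the following; the `mlxBackend` type here is an assumption for illustration, not part of this diff.

```go
// mlxBackend is an assumed type used only for illustration.
type mlxBackend struct{ /* ... */ }

// CacheConfig tells the KV cache how this backend wants its output shaped.
func (b *mlxBackend) CacheConfig() CacheConfig {
	return CacheConfig{
		CachePadding:     32,   // return history in multiples of 32 tokens
		PermutedV:        true, // store V pre-permuted for the attention kernel
		MaskBatchPadding: 8,    // pad the mask's batch dimension; padded slots become -Inf
	}
}
```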
|
||||||
|
|
||||||
|
// BackendParams controls how the backend loads and executes models
|
||||||
|
type BackendParams struct {
|
||||||
|
// AllocMemory causes the backend to allocate memory for the model. If
|
||||||
|
// false, this is only being used for discovering the required amount of
|
||||||
|
// memory and cannot load the model for running.
|
||||||
|
AllocMemory bool
|
||||||
|
|
||||||
|
// NumThreads sets the number of threads to use if running on the CPU
|
||||||
|
NumThreads int
|
||||||
|
|
||||||
|
// GPULayers is the set of layers to offload to GPUs
|
||||||
|
GPULayers GPULayersList
|
||||||
|
|
||||||
|
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||||
|
FlashAttention bool
|
||||||
|
}
|
var backends = make(map[string]func(string, BackendParams) (Backend, error))

func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
    if _, ok := backends[name]; ok {
        panic("backend: backend already registered")
    }

    backends[name] = f
}

func NewBackend(modelPath string, params BackendParams) (Backend, error) {
    be := os.Getenv("OLLAMA_BACKEND")
    if be == "" {
        be = "mlx"
        slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
    }
    slog.Info("Loading new engine", "backend", be)
    if backend, ok := backends[be]; ok {
        return backend(modelPath, params)
    }

    return nil, fmt.Errorf("unsupported backend")
}
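The map above is a typical Go driver registry: each backend registers a constructor from an init function and gets linked in via a blank import (the new x/ml/backend/backend.go later in this diff keeps exactly such an import commented out). A minimal sketch of the pattern, assuming a hypothetical `mock` backend package; `newMockBackend` is a placeholder constructor, not code from this change.

```go
package mock

import "github.com/ollama/ollama/x/ml"

func init() {
    // Register under the name a caller can select with OLLAMA_BACKEND=mock.
    ml.RegisterBackend("mock", func(modelPath string, params ml.BackendParams) (ml.Backend, error) {
        return newMockBackend(modelPath, params) // hypothetical constructor
    })
}
```

NewBackend then resolves the name from OLLAMA_BACKEND (defaulting to "mlx") and calls the registered constructor.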
type Context interface {
    Empty(dtype DType, shape ...int) Tensor
    Zeros(dtype DType, shape ...int) Tensor
    // FromBytes(dtype DType, s []byte, shape ...int) Tensor
    FromFloats(s []float32, shape ...int) Tensor
    FromInts(s []int32, shape ...int) Tensor
    RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor

    // Arange creates a 1D tensor with values in the half-open interval [start, stop), incremented by step.
    Arange(start, stop, step float32, dtype DType) Tensor

    Forward(...Tensor) Context

    // SetBatchSize provides a hint on the batch size to optimize processing
    // Uses heuristics if not set
    // SetBatchSize(int)

    Compute(...Tensor)
    // ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun

    // Reserve is analogous to Compute but rather than executing a
    // graph, simply preallocates memory. Typically called with a
    // worst case graph to ensure all resources are available
    // for future inference.
    // Reserve()

    // MaxGraphNodes() int
    Close()

    // Input returns a context appropriate for creating tensors that are
    // inputs to the model (which includes things like output locations)
    Input() Context

    // Layer returns a context appropriate for creating intermediate tensors
    Layer(int) Context

    // CompareWith loads tensors from the "filename" safetensors file and compares them with the input tensors.
    // It returns an error if a shape is inconsistent or a similarity measure is below 99%.
    CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
type RoPEOptions struct {
    Base  *float32
    Freqs Tensor
}

func WithRoPEBase(base float32) func(*RoPEOptions) {
    return func(opts *RoPEOptions) {
        opts.Base = &base
    }
}

func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
    return func(opts *RoPEOptions) {
        opts.Freqs = freqs
    }
}
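These are conventional functional options for the Tensor.RoPE method declared just below. A minimal usage sketch, written as if inside the same package; the head dimension (128) and frequency base (10000) are illustrative placeholders, not values taken from this change.

```go
// applyRoPE rotates q with a non-default frequency base.
func applyRoPE(ctx Context, q Tensor) Tensor {
    return q.RoPE(ctx, 128, false /* traditional */, 1.0 /* scale */, 0 /* offset */, WithRoPEBase(10000))
}
```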
type Tensor interface {
    ToString() string
    RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
    ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
    TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
    // TakeAxes(ctx Context, axes int, indicies ...int) Tensor

    Dim(n int) int
    Stride(n int) int

    Shape() []int
    DType() DType
    // Cast(ctx Context, dtype DType) Tensor

    // Bytes() []byte
    Floats() []float32
    Ints() []int32

    // FromBytes([]byte)
    // FromFloats([]float32)
    // FromInts([]int32)

    Add(ctx Context, t2 Tensor) Tensor
    Sub(ctx Context, t2 Tensor) Tensor
    // Mul(ctx Context, t2 Tensor) Tensor
    // Div(ctx Context, t2 Tensor) Tensor

    Max(ctx Context, axes []int, keepDims bool) Tensor
    Min(ctx Context, axes []int, keepDims bool) Tensor

    Matmul(ctx Context, a2 Tensor) Tensor
    // Mulmat(ctx Context, t2 Tensor) Tensor
    // MulmatFullPrec(ctx Context, t2 Tensor) Tensor
    // MulmatID(ctx Context, t2, ids Tensor) Tensor
    // AddID(ctx Context, t2, ids Tensor) Tensor

    Softmax(ctx Context) Tensor
    L2Norm(ctx Context, eps float32) Tensor
    LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
    RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
    Scale(ctx Context, s float64) Tensor
    // SumRows(ctx Context) Tensor

    AvgPool2D(ctx Context, k, s int, p float32) Tensor
    Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
    Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor

    // IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

    // Sin(ctx Context) Tensor
    // Cos(ctx Context) Tensor
    // Tanh(ctx Context) Tensor
    GELU(ctx Context, up ...Tensor) Tensor
    // QuickGELU(ctx Context, up ...Tensor) Tensor
    // SILU(ctx Context, up ...Tensor) Tensor
    // RELU(ctx Context, up ...Tensor) Tensor
    // Sigmoid(ctx Context) Tensor

    // SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
    // SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

    Reshape(ctx Context, shape ...int) Tensor
    AsStrided(ctx Context, shape, strides []int, offset int) Tensor
    Transpose(ctx Context, shape ...int) Tensor
    Contiguous(ctx Context, allowColMajor bool) Tensor

    // Pad(ctx Context, shape ...int) Tensor

    // Stack(ctx Context, dim int, s ...Tensor) Tensor

    // Repeat repeats the tensor n times along dimension dim
    // Repeat(ctx Context, dim, n int) Tensor
    // Concat(ctx Context, t2 Tensor, dim int) Tensor
    // Rows(ctx Context, t2 Tensor) Tensor

    // TODO these probably aren't actually needed - false starts on trying to wire up cache
    // SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
    // SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
    // PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor

    Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor

    Copy(ctx Context, t2 Tensor) Tensor
    // Duplicate(ctx Context) Tensor

    // Slice(ctx Context, dim, low, high, step int) Tensor
    // Chunk(ctx Context, dim int, size int) []Tensor
    // ChunkSections(ctx Context, dim int, sections ...int) []Tensor

    // TopK(ctx Context, k int) Tensor
    // Argsort(ctx Context) Tensor
    // Mean(ctx Context) Tensor
    // Variance(ctx Context) Tensor
    // Stddev(ctx Context) Tensor
    // Sqr(ctx Context) Tensor
    // Sqrt(ctx Context) Tensor

    // Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
//	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
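For comparison with the unfused sequence above, here is a sketch of calling the fused method declared in the Tensor interface, written as if inside the same package and using the standard library math package. The tensor arguments and the "causal" maskMode string are assumptions for illustration and are not confirmed by this diff.

```go
// fusedAttention shows only the call shape; query, key and value are assumed to
// already be laid out as the backend expects, and headDim is a placeholder.
func fusedAttention(ctx Context, query, key, value Tensor, headDim int) Tensor {
    scale := 1.0 / math.Sqrt(float64(headDim)) // conventional 1/sqrt(d_k) scaling
    return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, nil)
}
```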
||||||
|
|
||||||
|
// type number interface {
|
||||||
|
// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
|
||||||
|
// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||||
|
// ~float32 | ~float64 |
|
||||||
|
// ~complex64 | ~complex128
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func mul[T number](s ...T) T {
|
||||||
|
// p := T(1)
|
||||||
|
// for _, v := range s {
|
||||||
|
// p *= v
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return p
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type DumpOptions func(*dumpOptions)
|
||||||
|
|
||||||
|
// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
|
||||||
|
// func DumpWithPrecision(n int) DumpOptions {
|
||||||
|
// return func(opts *dumpOptions) {
|
||||||
|
// opts.Precision = n
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
|
||||||
|
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
|
||||||
|
// // beginning and end of each dimension will be printed.
|
||||||
|
// func DumpWithThreshold(n int) DumpOptions {
|
||||||
|
// return func(opts *dumpOptions) {
|
||||||
|
// opts.Threshold = n
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
|
||||||
|
// func DumpWithEdgeItems(n int) DumpOptions {
|
||||||
|
// return func(opts *dumpOptions) {
|
||||||
|
// opts.EdgeItems = n
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type dumpOptions struct {
|
||||||
|
// Precision, Threshold, EdgeItems int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
|
||||||
|
// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
|
||||||
|
// for _, optsFunc := range optsFuncs {
|
||||||
|
// optsFunc(&opts)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if mul(t.Shape()...) <= opts.Threshold {
|
||||||
|
// opts.EdgeItems = math.MaxInt
|
||||||
|
// }
|
||||||
|
|
||||||
|
// switch t.DType() {
|
||||||
|
// case DTypeFloat32:
|
||||||
|
// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
|
||||||
|
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||||
|
// })
|
||||||
|
// case DTypeFloat16: // TODO other types...
|
||||||
|
// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
|
||||||
|
// f32 = t.Copy(ctx, f32)
|
||||||
|
// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
|
||||||
|
// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||||
|
// })
|
||||||
|
// case DTypeInt32:
|
||||||
|
// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
|
||||||
|
// return strconv.FormatInt(int64(i), 10)
|
||||||
|
// })
|
||||||
|
// default:
|
||||||
|
// return "<unsupported>"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||||
|
// if t.Bytes() == nil {
|
||||||
|
// ctx.Compute(t)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// s := make(S, mul(t.Shape()...))
|
||||||
|
// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// shape := t.Shape()
|
||||||
|
// slices.Reverse(shape)
|
||||||
|
|
||||||
|
// var sb strings.Builder
|
||||||
|
// var f func([]int, int)
|
||||||
|
// f = func(dims []int, stride int) {
|
||||||
|
// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
||||||
|
// sb.WriteString("[")
|
||||||
|
// defer func() { sb.WriteString("]") }()
|
||||||
|
// for i := 0; i < dims[0]; i++ {
|
||||||
|
// if i >= items && i < dims[0]-items {
|
||||||
|
// sb.WriteString("..., ")
|
||||||
|
// // skip to next printable element
|
||||||
|
// skip := dims[0] - 2*items
|
||||||
|
// if len(dims) > 1 {
|
||||||
|
// stride += mul(append(dims[1:], skip)...)
|
||||||
|
// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
||||||
|
// }
|
||||||
|
// i += skip - 1
|
||||||
|
// } else if len(dims) > 1 {
|
||||||
|
// f(dims[1:], stride)
|
||||||
|
// stride += mul(dims[1:]...)
|
||||||
|
// if i < dims[0]-1 {
|
||||||
|
// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
||||||
|
// }
|
||||||
|
// } else {
|
||||||
|
// text := fn(s[stride+i])
|
||||||
|
// if len(text) > 0 && text[0] != '-' {
|
||||||
|
// sb.WriteString(" ")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// sb.WriteString(text)
|
||||||
|
// if i < dims[0]-1 {
|
||||||
|
// sb.WriteString(", ")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// f(shape, 0)
|
||||||
|
|
||||||
|
// return sb.String()
|
||||||
|
// }
|
type DType int

const (
    DTypeBool DType = iota
    DTypeUint8
    DTypeUint16
    DTypeUint32
    DTypeUint64
    DTypeInt8
    DTypeInt16
    DTypeInt32
    DTypeInt64
    DTypeFloat16
    DTypeFloat32
    DTypeFloat64
    DTypeBfloat16
    DTypeComplex64
)

type SamplingMode int

const (
    SamplingModeNearest SamplingMode = iota
    SamplingModeBilinear
)
3
x/ml/backend/backend.go
Normal file
@@ -0,0 +1,3 @@
package backend

// _ "github.com/ollama/ollama/x/ml/backend/mlx"
57
x/ml/backend/mlx/CMakeLists.txt
Normal file
@@ -0,0 +1,57 @@
include(FetchContent)

set(MLX_C_BUILD_EXAMPLES OFF)

set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)

function(set_target_output_directory _target)
  if(TARGET ${_target})
    set_target_properties(${_target} PROPERTIES
      RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
      LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
      ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
    )
  endif()
endfunction()

# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
  execute_process(
    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

  if(NOT MLX_METAL_VERSION)
    message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
    set(MLX_BUILD_METAL OFF)
  endif()
else()
  # On Linux, disable Metal backend
  message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
  set(MLX_BUILD_METAL OFF)
endif()

# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
  set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
  message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()

# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
  set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
  message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
  message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()

FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG v0.4.1)
FetchContent_MakeAvailable(mlx-c)

set_target_output_directory(mlx)
set_target_output_directory(mlxc)
1278
x/ml/backend/mlx/mlx.go
Normal file
(diff contents not shown)

314
x/ml/backend/mlx/mlx_test.go
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package mlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/runner/common"
|
||||||
|
"github.com/ollama/ollama/sample"
|
||||||
|
"github.com/ollama/ollama/x/ml"
|
||||||
|
"github.com/ollama/ollama/x/model"
|
||||||
|
"github.com/ollama/ollama/x/model/input"
|
||||||
|
_ "github.com/ollama/ollama/x/model/models/gemma3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||||
|
slog.SetDefault(logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel(t *testing.T) {
|
||||||
|
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||||
|
b := &Backend{}
|
||||||
|
err := b.LoadSafeTensors(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load failed: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromInts(t *testing.T) {
|
||||||
|
b := &Backend{}
|
||||||
|
c := b.NewContext()
|
||||||
|
defer c.Close()
|
||||||
|
data := []int32{1, 2, 3, 4, 5, 6}
|
||||||
|
a := c.FromInts(data, 2, 3)
|
||||||
|
slog.Info("", "array", a)
|
||||||
|
t.Log(a.ToString())
|
||||||
|
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||||
|
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromFloats(t *testing.T) {
|
||||||
|
b := &Backend{}
|
||||||
|
c := b.NewContext()
|
||||||
|
defer c.Close()
|
||||||
|
data := []float32{1, 2, 3, 4, 5, 6}
|
||||||
|
a := c.FromFloats(data, 2, 3)
|
||||||
|
slog.Info("", "array", a)
|
||||||
|
t.Log(a.ToString())
|
||||||
|
if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
|
||||||
|
t.Fatalf("incorrect shape: %v", a.Shape())
|
||||||
|
}
|
||||||
|
res := a.Floats()
|
||||||
|
if !reflect.DeepEqual(res, data) {
|
||||||
|
t.Fatalf("incorrect results: %v", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdd(t *testing.T) {
|
||||||
|
b := &Backend{}
|
||||||
|
c := b.NewContext()
|
||||||
|
defer c.Close()
|
||||||
|
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||||
|
t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
|
||||||
|
exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
|
||||||
|
t3 := t1.Add(c, t2)
|
||||||
|
c.Compute(t3, exp)
|
||||||
|
t3f := t3.Floats()
|
||||||
|
if !reflect.DeepEqual(t3f, exp.Floats()) {
|
||||||
|
t.Fatalf("incorrect result: %v", t3f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReshapeTranspose(t *testing.T) {
|
||||||
|
b := &Backend{}
|
||||||
|
c := b.NewContext()
|
||||||
|
defer c.Close()
|
||||||
|
t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
|
||||||
|
c.Compute(t1)
|
||||||
|
t1f := t1.Floats()
|
||||||
|
exp := []float32{
|
||||||
|
0, 4, 8,
|
||||||
|
1, 5, 9,
|
||||||
|
2, 6, 10,
|
||||||
|
3, 7, 11,
|
||||||
|
12, 16, 20,
|
||||||
|
13, 17, 21,
|
||||||
|
14, 18, 22,
|
||||||
|
15, 19, 23,
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(t1f, exp) {
|
||||||
|
t.Fatalf("incorrect results: %v", t1f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func prod(vals ...int) int {
|
||||||
|
r := 1
|
||||||
|
for _, v := range vals {
|
||||||
|
r *= v
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
func TestMatmul(t *testing.T) {
|
||||||
|
// TODO create scenarios...
|
||||||
|
b := &Backend{}
|
||||||
|
c := b.NewContext()
|
||||||
|
defer c.Close()
|
||||||
|
s1 := []int{1, 3, 2, 4}
|
||||||
|
t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
|
||||||
|
s2 := []int{4, 2}
|
||||||
|
t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
|
||||||
|
t3 := t1.Matmul(c, t2)
|
||||||
|
exp := []float32{
|
||||||
|
28, 34,
|
||||||
|
76, 98,
|
||||||
|
|
||||||
|
124, 162,
|
||||||
|
172, 226,
|
||||||
|
|
||||||
|
220, 290,
|
||||||
|
268, 354,
|
||||||
|
}
|
||||||
|
c.Compute(t3)
|
||||||
|
t3f := t3.Floats()
|
||||||
|
if !reflect.DeepEqual(t3f, exp) {
|
||||||
|
t.Fatalf("incorrect result: %v", t3f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRows(t *testing.T) {
|
||||||
|
b := &Backend{}
|
||||||
|
c := b.NewContext()
|
||||||
|
defer c.Close()
|
||||||
|
t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
|
||||||
|
outputs := c.Zeros(ml.DTypeInt32, 1)
|
||||||
|
t2 := t1.TakeAxes(c, outputs, 1)
|
||||||
|
c.Forward(t1, t2).Compute(t1, t2)
|
||||||
|
t.Log(t1.ToString())
|
||||||
|
t.Log(t2.ToString())
|
||||||
|
f := t2.Floats()
|
||||||
|
t.Logf("Result: %v", f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCaching(t *testing.T) {
|
||||||
|
// Validate the caching algorithm
|
||||||
|
b := &Backend{}
|
||||||
|
c := b.NewContext()
|
||||||
|
defer c.Close()
|
||||||
|
batchSize := 3
|
||||||
|
headDim := 4
|
||||||
|
numKVHeads := 2
|
||||||
|
// Make cache twice the size of one test batch
|
||||||
|
cells := batchSize * 2
|
||||||
|
cellSize := numKVHeads * headDim
|
||||||
|
shape := []int{1, numKVHeads, batchSize, headDim}
|
||||||
|
stop := float32(1)
|
||||||
|
for _, x := range shape {
|
||||||
|
stop *= float32(x)
|
||||||
|
}
|
||||||
|
// Create the cache
|
||||||
|
cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
|
||||||
|
t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
|
||||||
|
|
||||||
|
// Input tensor
|
||||||
|
t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
|
||||||
|
t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
|
||||||
|
|
||||||
|
// Reshape to copy into the cache
|
||||||
|
/*
|
||||||
|
From MLX python/src/indexing.cpp mlx_scatter_args_array
|
||||||
|
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||||
|
auto up_shape = indices.shape();
|
||||||
|
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||||
|
up = broadcast_to(up, up_shape);
|
||||||
|
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||||
|
up = reshape(up, up_shape);
|
||||||
|
*/
|
||||||
|
numRows := 3
|
||||||
|
up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
|
||||||
|
t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
|
||||||
|
|
||||||
|
// Simulate cells 1,3,5 are available
|
||||||
|
indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
|
||||||
|
t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
|
||||||
|
axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
|
||||||
|
cache.Scatter(c, indicies, up, axis)
|
||||||
|
|
||||||
|
c.Forward(cache)
|
||||||
|
// Cache should contain the data now
|
||||||
|
t.Log("Cache after put\n" + cache.ToString())
|
||||||
|
|
||||||
|
// Retrieve cache content and verify it matches
|
||||||
|
out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
|
||||||
|
t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
|
||||||
|
|
||||||
|
t1f := t1.Floats()
|
||||||
|
outf := out.Floats()
|
||||||
|
if !reflect.DeepEqual(t1f, outf) {
|
||||||
|
t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGemma3(t *testing.T) {
|
||||||
|
// Why is the sky blue
|
||||||
|
inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
|
||||||
|
limit := 50
|
||||||
|
|
||||||
|
// TODO generalize this
|
||||||
|
dir := "/Users/daniel/Models/gemma-3-4b-it/"
|
||||||
|
|
||||||
|
m, err := model.New(dir, ml.BackendParams{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to load model: %s", err)
|
||||||
|
}
|
||||||
|
b := m.Backend()
|
||||||
|
ctx := b.NewContext()
|
||||||
|
defer ctx.Close()
|
||||||
|
|
||||||
|
batch := input.Batch{
|
||||||
|
Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
|
||||||
|
Positions: make([]int32, len(inputs)),
|
||||||
|
Sequences: make([]int, len(inputs)),
|
||||||
|
Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
|
||||||
|
Offset: 0,
|
||||||
|
}
|
||||||
|
for i := range len(inputs) {
|
||||||
|
batch.Positions[i] = int32(i)
|
||||||
|
}
|
||||||
|
offset := len(inputs)
|
||||||
|
|
||||||
|
cache := m.Config().Cache
|
||||||
|
if cache != nil {
|
||||||
|
numSlots := 1
|
||||||
|
batchSize := 512
|
||||||
|
numCtx := 4096
|
||||||
|
|
||||||
|
// Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
|
||||||
|
// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
|
||||||
|
cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
|
||||||
|
|
||||||
|
cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
|
||||||
|
err := cache.StartForward(ctx, batch, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed cache.StartForward: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
var grammar *sample.GrammarSampler
|
||||||
|
sampler := sample.NewSampler(
|
||||||
|
opts.Temperature,
|
||||||
|
opts.TopK,
|
||||||
|
opts.TopP,
|
||||||
|
opts.MinP,
|
||||||
|
opts.Seed,
|
||||||
|
grammar,
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Log("Starting Forward pass loop")
|
||||||
|
pendingResponses := []string{}
|
||||||
|
for {
|
||||||
|
out, err := m.Forward(ctx, batch)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed forward pass: %s", err)
|
||||||
|
}
|
||||||
|
ctx.Forward(out)
|
||||||
|
outputs := out.Floats()
|
||||||
|
t.Logf("finished forward pass! length:%d", len(outputs))
|
||||||
|
// sample a token
|
||||||
|
logits := outputs
|
||||||
|
token, err := sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to sample token: %s", err)
|
||||||
|
}
|
||||||
|
t.Logf("Sampled token: %v", token)
|
||||||
|
if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||||
|
t.Log("hit EOS")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
piece, err := m.(model.TextProcessor).Decode([]int32{token})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to decode token: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingResponses = append(pendingResponses, piece)
|
||||||
|
sequence := strings.Join(pendingResponses, "")
|
||||||
|
if ok, stop := common.FindStop(sequence, opts.Stop); ok {
|
||||||
|
t.Logf("hit stop token: %v", stop)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t.Logf("RESULTS: %s", sequence)
|
||||||
|
batch = input.Batch{
|
||||||
|
Inputs: ctx.FromInts([]int32{token}, 1, 1),
|
||||||
|
Positions: make([]int32, 1),
|
||||||
|
Sequences: make([]int, 1),
|
||||||
|
Outputs: ctx.FromInts([]int32{0}, 1),
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
offset++
|
||||||
|
batch.Positions[0] = 0
|
||||||
|
err = cache.StartForward(ctx, batch, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed cache.StartForward: %s", err)
|
||||||
|
}
|
||||||
|
if offset > limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
335
x/ml/backend/mlx/quant.go
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package mlx
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "mlx/c/array.h"
|
||||||
|
#include "mlx/c/ops.h"
|
||||||
|
|
||||||
|
// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
|
||||||
|
|
||||||
|
void unpack_32_4(uint8_t* data, int8_t* dst) {
|
||||||
|
memset(dst, 0, 16);
|
||||||
|
for (int j = 0; j < 16; ++j) {
|
||||||
|
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
|
||||||
|
if (j % 2 != 0) {
|
||||||
|
x <<= 4;
|
||||||
|
}
|
||||||
|
dst[j / 2] += x;
|
||||||
|
}
|
||||||
|
// Last 16 weights are in the higher bits
|
||||||
|
for (int j = 0; j < 16; ++j) {
|
||||||
|
uint8_t x = (data[j + 2] >> 4);
|
||||||
|
if (j % 2 != 0) {
|
||||||
|
x <<= 4;
|
||||||
|
}
|
||||||
|
dst[8 + j / 2] += x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts (weight, scales, biases) from Q4_0 tensors.
|
||||||
|
// Data layout is: |16 bit scale|32 x 4bit weights|.
|
||||||
|
void extract_q4_0_data(
|
||||||
|
uint8_t* data,
|
||||||
|
mlx_array* weights_arr,
|
||||||
|
mlx_array* scales_arr,
|
||||||
|
mlx_array* biases_arr) {
|
||||||
|
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
|
||||||
|
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||||
|
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||||
|
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||||
|
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||||
|
scales[i] = *((float16_t*)data);
|
||||||
|
biases[i] = -8 * scales[i];
|
||||||
|
unpack_32_4(data, weights);
|
||||||
|
weights += 16;
|
||||||
|
data += bytes_per_block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts (weight, scales, biases) from Q4_1 tensors.
|
||||||
|
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
|
||||||
|
void extract_q4_1_data(
|
||||||
|
uint8_t* data,
|
||||||
|
mlx_array* weights_arr,
|
||||||
|
mlx_array* scales_arr,
|
||||||
|
mlx_array* biases_arr) {
|
||||||
|
const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
|
||||||
|
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||||
|
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||||
|
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||||
|
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||||
|
scales[i] = *((float16_t*)data);
|
||||||
|
biases[i] = *((float16_t*)(data) + 1);
|
||||||
|
unpack_32_4(data, weights);
|
||||||
|
weights += 16;
|
||||||
|
data += bytes_per_block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts (weight, scales, biases) from Q8_0 tensors.
|
||||||
|
// Data layout is: |16 bit scale|32 x 8bit weights|.
|
||||||
|
void extract_q8_0_data(
|
||||||
|
uint8_t* data,
|
||||||
|
mlx_array* weights_arr,
|
||||||
|
mlx_array* scales_arr,
|
||||||
|
mlx_array* biases_arr) {
|
||||||
|
const uint64_t weights_per_block = 32;
|
||||||
|
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
|
||||||
|
uint8_t* weights = mlx_array_data_uint8(*weights_arr);
|
||||||
|
float16_t* scales = mlx_array_data_float16(*scales_arr);
|
||||||
|
float16_t* biases = mlx_array_data_float16(*biases_arr);
|
||||||
|
for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
|
||||||
|
uint8_t* block_data = data + i * bytes_per_block;
|
||||||
|
scales[i] = *((float16_t*)block_data);
|
||||||
|
biases[i] = -128 * scales[i];
|
||||||
|
for (int64_t j = 0; j < weights_per_block; ++j) {
|
||||||
|
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
|
||||||
|
// Original data is in int8_t, so we add a bias of -128 and invert the
|
||||||
|
// first bit.
|
||||||
|
x ^= 1 << 7;
|
||||||
|
weights[i * weights_per_block + j] = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derived from ggml-quants.c
|
||||||
|
|
||||||
|
#define QK_K 256
|
||||||
|
|
||||||
|
// 6-bit quantization
|
||||||
|
// weight is represented as x = a * q
|
||||||
|
// 16 blocks of 16 elements each
|
||||||
|
// Effectively 6.5625 bits per weight
|
||||||
|
typedef struct {
|
||||||
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||||
|
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||||
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||||
|
uint16_t d; // super-block scale
|
||||||
|
} block_q6_K;
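// Editorial note, not part of this patch: the "6.5625 bits per weight" figure
// above follows directly from the struct layout:
//   bytes per block = QK_K/2 + QK_K/4 + QK_K/16 + 2 = 128 + 64 + 16 + 2 = 210
//   bits per weight = 210 * 8 / QK_K = 1680 / 256 = 6.5625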
|
||||||
|
|
||||||
|
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
|
||||||
|
const int64_t nb = k / QK_K;
|
||||||
|
block_q6_K *x = (block_q6_K *)vx;
|
||||||
|
float16_t* y = (float16_t *)vy;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
float16_t d = 0.0;
|
||||||
|
memcpy(&d, &x[i].d, sizeof(d));
|
||||||
|
|
||||||
|
const uint8_t * restrict ql = x[i].ql;
|
||||||
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
const int8_t * restrict sc = x[i].scales;
|
||||||
|
|
||||||
|
for (int n = 0; n < QK_K; n += 128) {
|
||||||
|
for (int l = 0; l < 32; ++l) {
|
||||||
|
int is = l/16;
|
||||||
|
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
||||||
|
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
||||||
|
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
||||||
|
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
||||||
|
y[l + 0] = d * sc[is + 0] * q1;
|
||||||
|
y[l + 32] = d * sc[is + 2] * q2;
|
||||||
|
y[l + 64] = d * sc[is + 4] * q3;
|
||||||
|
y[l + 96] = d * sc[is + 6] * q4;
|
||||||
|
}
|
||||||
|
y += 128;
|
||||||
|
ql += 64;
|
||||||
|
qh += 32;
|
||||||
|
sc += 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define K_SCALE_SIZE 12
|
||||||
|
#define GGML_COMMON_AGGR_U
|
||||||
|
#define GGML_COMMON_AGGR_S
|
||||||
|
|
||||||
|
// 4-bit quantization
|
||||||
|
// 8 blocks of 32 elements each
|
||||||
|
// weight is represented as x = a * q + b
|
||||||
|
// Effectively 4.5 bits per weight
|
||||||
|
typedef struct {
|
||||||
|
union {
|
||||||
|
struct {
|
||||||
|
uint16_t d; // super-block scale for quantized scales
|
||||||
|
uint16_t dmin; // super-block scale for quantized mins
|
||||||
|
} GGML_COMMON_AGGR_S;
|
||||||
|
uint16_t dm;
|
||||||
|
} GGML_COMMON_AGGR_U;
|
||||||
|
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||||
|
uint8_t qs[QK_K/2]; // 4-bit quants
|
||||||
|
} block_q4_K;
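// Editorial note, not part of this patch: the "4.5 bits per weight" figure checks out:
//   bytes per block = 2 + 2 + K_SCALE_SIZE + QK_K/2 = 2 + 2 + 12 + 128 = 144
//   bits per weight = 144 * 8 / QK_K = 1152 / 256 = 4.5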
|
||||||
|
|
||||||
|
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
||||||
|
if (j < 4) {
|
||||||
|
*d = q[j] & 63; *m = q[j + 4] & 63;
|
||||||
|
} else {
|
||||||
|
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
||||||
|
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
|
||||||
|
block_q4_K *x = (block_q4_K *)vx;
|
||||||
|
float16_t* y = (float16_t *)vy;
|
||||||
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
const uint8_t * q = x[i].qs;
|
||||||
|
float16_t d = 0.0;
|
||||||
|
memcpy(&d, &x[i].d, sizeof(d));
|
||||||
|
float16_t min = 0.0;
|
||||||
|
memcpy(&min, &x[i].dmin, sizeof(d));
|
||||||
|
|
||||||
|
int is = 0;
|
||||||
|
uint8_t sc, m;
|
||||||
|
for (int j = 0; j < QK_K; j += 64) {
|
||||||
|
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
|
||||||
|
const float16_t d1 = d * sc; const float16_t m1 = min * m;
|
||||||
|
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
|
||||||
|
const float16_t d2 = d * sc; const float16_t m2 = min * m;
|
||||||
|
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
|
||||||
|
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
||||||
|
q += 32; is += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/x448/float16"
|
||||||
|
)
|
||||||
|
|
||||||
|
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||||
|
shape := append([]C.int{}, final_shape...)
|
||||||
|
var weights_per_byte C.int
|
||||||
|
if dtype == 2 || dtype == 3 {
|
||||||
|
weights_per_byte = 2
|
||||||
|
} else if dtype == 8 {
|
||||||
|
weights_per_byte = 1
|
||||||
|
} else {
|
||||||
|
return r, fmt.Errorf("unsupported tensor type %d", dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
weights_per_block := C.int(32)
|
||||||
|
if shape[len(shape)-1]%weights_per_block != 0 {
|
||||||
|
return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
weights_shape := append([]C.int{}, shape...)
|
||||||
|
weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
|
||||||
|
w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
|
||||||
|
for i := range weights_shape {
|
||||||
|
w_nbytes *= weights_shape[i]
|
||||||
|
}
|
||||||
|
w_data := make([]byte, w_nbytes)
|
||||||
|
cbytes := C.CBytes(w_data)
|
||||||
|
defer C.free(cbytes)
|
||||||
|
weights := C.mlx_array_new_data(
|
||||||
|
cbytes,
|
||||||
|
&weights_shape[0],
|
||||||
|
C.int(len(weights_shape)),
|
||||||
|
C.MLX_UINT32,
|
||||||
|
)
|
||||||
|
|
||||||
|
// For scales and bias
|
||||||
|
shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
|
||||||
|
sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
|
||||||
|
for i := range shape {
|
||||||
|
sb_nbytes *= shape[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
s_data := make([]byte, sb_nbytes)
|
||||||
|
cbytes = C.CBytes(s_data)
|
||||||
|
defer C.free(cbytes)
|
||||||
|
scales := C.mlx_array_new_data(
|
||||||
|
cbytes,
|
||||||
|
&shape[0],
|
||||||
|
C.int(len(shape)),
|
||||||
|
C.MLX_FLOAT16,
|
||||||
|
)
|
||||||
|
b_data := make([]byte, sb_nbytes)
|
||||||
|
cbytes = C.CBytes(b_data)
|
||||||
|
defer C.free(cbytes)
|
||||||
|
biases := C.mlx_array_new_data(
|
||||||
|
cbytes,
|
||||||
|
&shape[0],
|
||||||
|
C.int(len(shape)),
|
||||||
|
C.MLX_FLOAT16,
|
||||||
|
)
|
||||||
|
var bits C.int
|
||||||
|
switch dtype {
|
||||||
|
case 2:
|
||||||
|
C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||||
|
bits = 4
|
||||||
|
case 3:
|
||||||
|
C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||||
|
bits = 4
|
||||||
|
case 8:
|
||||||
|
C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
|
||||||
|
bits = 8
|
||||||
|
}
|
||||||
|
groupSize := C.mlx_optional_int{value: 32, has_value: true}
|
||||||
|
bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
|
||||||
|
var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
|
||||||
|
C.mlx_dequantize(
|
||||||
|
&r,
|
||||||
|
weights,
|
||||||
|
scales,
|
||||||
|
biases,
|
||||||
|
groupSize,
|
||||||
|
bitsOpt,
|
||||||
|
nil, // TODO mode
|
||||||
|
dtypeOpt,
|
||||||
|
stream,
|
||||||
|
)
|
||||||
|
C.mlx_array_free(weights)
|
||||||
|
C.mlx_array_free(scales)
|
||||||
|
C.mlx_array_free(biases)
|
||||||
|
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
|
||||||
|
size := 1
|
||||||
|
for _, d := range shape {
|
||||||
|
size *= int(d)
|
||||||
|
}
|
||||||
|
fdata := make([]float16.Float16, size)
|
||||||
|
switch dtype {
|
||||||
|
case 14:
|
||||||
|
C.dequant_row_q6_K(
|
||||||
|
data,
|
||||||
|
unsafe.Pointer(&fdata[0]),
|
||||||
|
C.int(size),
|
||||||
|
)
|
||||||
|
|
||||||
|
case 12:
|
||||||
|
C.dequant_row_q4_K(
|
||||||
|
data,
|
||||||
|
unsafe.Pointer(&fdata[0]),
|
||||||
|
C.int(size),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return r, fmt.Errorf("unsupported K quant")
|
||||||
|
}
|
||||||
|
|
||||||
|
r = C.mlx_array_new_data(
|
||||||
|
unsafe.Pointer(&fdata[0]),
|
||||||
|
&shape[0],
|
||||||
|
C.int(len(shape)),
|
||||||
|
C.MLX_FLOAT16,
|
||||||
|
)
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
643
x/ml/device.go
Normal file
@@ -0,0 +1,643 @@
|
|||||||
|
package ml
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"hash/maphash"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GPULayers is a set of layers to be allocated on a single GPU
|
||||||
|
type GPULayers struct {
|
||||||
|
DeviceID
|
||||||
|
|
||||||
|
// Layers is a set of layer indices to load
|
||||||
|
Layers []int
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
|
||||||
|
func (g GPULayers) FirstLayer() int {
|
||||||
|
if len(g.Layers) == 0 {
|
||||||
|
return math.MaxInt
|
||||||
|
}
|
||||||
|
|
||||||
|
first := g.Layers[0]
|
||||||
|
for i := 1; i < len(g.Layers); i++ {
|
||||||
|
if g.Layers[i] < first {
|
||||||
|
first = g.Layers[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return first
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GPULayers) String() string {
|
||||||
|
if len(g.Layers) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.Sort(g.Layers)
|
||||||
|
|
||||||
|
contiguous := true
|
||||||
|
base := g.Layers[0]
|
||||||
|
for i := range g.Layers {
|
||||||
|
if g.Layers[i] != base+i {
|
||||||
|
contiguous = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if contiguous {
|
||||||
|
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
|
||||||
|
}
|
||||||
|
}
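A small illustration of the two output formats above, written as if inside the same package (the IDs and layer indices are hypothetical, and fmt is the standard library package):

```go
g := GPULayers{DeviceID: DeviceID{ID: "0"}, Layers: []int{0, 1, 2, 3}}
fmt.Println(g) // contiguous layers collapse to a range: "ID:0 Layers:4(0..3)"

g = GPULayers{DeviceID: DeviceID{ID: "0"}, Layers: []int{0, 2}}
fmt.Println(g) // a gap falls back to the explicit list: "ID:0 Layers:2[0 2]"
```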
|
||||||
|
|
||||||
|
// GPULayersList is a set of layer allocations across multiple GPUs
|
||||||
|
type GPULayersList []GPULayers
|
||||||
|
|
||||||
|
func (l GPULayersList) Len() int { return len(l) }
|
||||||
|
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
|
||||||
|
|
||||||
|
// Sort by the ordering of the layers offloaded
|
||||||
|
func (l GPULayersList) Less(i, j int) bool {
|
||||||
|
li := l[i].FirstLayer()
|
||||||
|
lj := l[j].FirstLayer()
|
||||||
|
|
||||||
|
return li < lj
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l GPULayersList) String() string {
|
||||||
|
if l.Sum() > 0 {
|
||||||
|
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("%v", []GPULayers(l))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum is the total number of layers assigned across all GPUs
|
||||||
|
func (l GPULayersList) Sum() int {
|
||||||
|
var sum int
|
||||||
|
|
||||||
|
for _, g := range l {
|
||||||
|
sum += len(g.Layers)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
var h maphash.Hash
|
||||||
|
|
||||||
|
// Hash is an identifier of this layer assignment
|
||||||
|
func (l GPULayersList) Hash() uint64 {
|
||||||
|
h.Reset()
|
||||||
|
for _, g := range l {
|
||||||
|
if len(g.Layers) > 0 {
|
||||||
|
h.WriteString(g.ID + g.Library)
|
||||||
|
for _, l := range g.Layers {
|
||||||
|
binary.Write(&h, binary.NativeEndian, int64(l))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.Sum64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrNoMem is returned when panicing due to insufficient memory. It includes
|
||||||
|
// the attempted memory allocation.
|
||||||
|
type ErrNoMem struct {
|
||||||
|
BackendMemory
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrNoMem) Error() string {
|
||||||
|
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Minimal unique device identification
|
||||||
|
type DeviceID struct {
|
||||||
|
// ID is an identifier for the device for matching with system
|
||||||
|
// management libraries. The ID is only unique for other devices
|
||||||
|
// using the same Library.
|
||||||
|
// This ID represents a "post filtered" view of the enumerated devices
|
||||||
|
// if the ID is numeric
|
||||||
|
ID string `json:"id"`
|
||||||
|
|
||||||
|
// Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
|
||||||
|
Library string `json:"backend,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceMemory provides a breakdown of the memory needed
|
||||||
|
// per device, such as a CPU or GPU.
|
||||||
|
type DeviceMemory struct {
|
||||||
|
DeviceID
|
||||||
|
|
||||||
|
// Name is the name of the device as labeled by the backend. It
|
||||||
|
// may not be persistent across instances of the runner.
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Weights is the per-layer memory needed for the model weights.
|
||||||
|
Weights []uint64
|
||||||
|
|
||||||
|
// Cache is the per-layer memory needed for the KV cache.
|
||||||
|
Cache []uint64
|
||||||
|
|
||||||
|
// Graph is the size of the compute graph. It is not per-layer.
|
||||||
|
Graph uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func sumMemory(mem []uint64) uint64 {
|
||||||
|
var sum uint64
|
||||||
|
|
||||||
|
for _, m := range mem {
|
||||||
|
sum += m
|
||||||
|
}
|
||||||
|
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the total size of the memory required by this device
|
||||||
|
func (m DeviceMemory) Size() uint64 {
|
||||||
|
return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
|
||||||
|
}
|
||||||
|
|
||||||
|
func memoryPresent(mem []uint64) bool {
|
||||||
|
return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m DeviceMemory) LogValue() slog.Value {
|
||||||
|
var attrs []slog.Attr
|
||||||
|
if memoryPresent(m.Weights) {
|
||||||
|
attrs = append(attrs, slog.Any("Weights", m.Weights))
|
||||||
|
}
|
||||||
|
|
||||||
|
if memoryPresent(m.Cache) {
|
||||||
|
attrs = append(attrs, slog.Any("Cache", m.Cache))
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Graph != 0 {
|
||||||
|
attrs = append(attrs, slog.Any("Graph", m.Graph))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(attrs) > 0 && m.ID != "" {
|
||||||
|
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return slog.GroupValue(attrs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackendMemory provides the amount of memory required to load the model
|
||||||
|
// per device based on the BackendParams. In some cases, not all required
|
||||||
|
// allocations will be known at this point. However, the size of the most recent
|
||||||
|
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||||
|
// accommodate that to make forward progress.
|
||||||
|
type BackendMemory struct {
|
||||||
|
// InputWeights are always located on the CPU and cannot be moved
|
||||||
|
InputWeights uint64
|
||||||
|
|
||||||
|
// CPU model components are located in system memory. This does not
|
||||||
|
// include unified memory allocated through the GPU.
|
||||||
|
CPU DeviceMemory
|
||||||
|
|
||||||
|
// GPU model components are located on one or more GPUs.
|
||||||
|
GPUs []DeviceMemory
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m BackendMemory) LogValue() slog.Value {
|
||||||
|
var attrs []slog.Attr
|
||||||
|
if m.InputWeights != 0 {
|
||||||
|
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
|
||||||
|
for _, g := range m.GPUs {
|
||||||
|
attrs = append(attrs, slog.Any(g.Name, g))
|
||||||
|
}
|
||||||
|
|
||||||
|
return slog.GroupValue(attrs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log prints a high level summary of the memory
|
||||||
|
func (m BackendMemory) Log(level slog.Level) {
|
||||||
|
var total uint64
|
||||||
|
|
||||||
|
for _, gpu := range m.GPUs {
|
||||||
|
if sum := sumMemory(gpu.Weights); sum > 0 {
|
||||||
|
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||||
|
total += sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
|
||||||
|
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||||
|
total += sum
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, gpu := range m.GPUs {
|
||||||
|
if sum := sumMemory(gpu.Cache); sum > 0 {
|
||||||
|
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||||
|
total += sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sum := sumMemory(m.CPU.Cache); sum > 0 {
|
||||||
|
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||||
|
total += sum
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, gpu := range m.GPUs {
|
||||||
|
if sum := gpu.Graph; sum > 0 {
|
||||||
|
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||||
|
total += sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sum := m.CPU.Graph; sum > 0 {
|
||||||
|
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||||
|
total += sum
|
||||||
|
}
|
||||||
|
|
||||||
|
if total > 0 {
|
||||||
|
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceInfo struct {
|
||||||
|
DeviceID
|
||||||
|
|
||||||
|
// Name is the name of the device as labeled by the backend. It
|
||||||
|
// may not be persistent across instances of the runner.
|
||||||
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
// Description is the longer user-friendly identification of the device
|
||||||
|
Description string `json:"description"`
|
||||||
|
|
||||||
|
// FilterID is populated with the unfiltered device ID if a numeric ID is used
|
||||||
|
// so the device can be included.
|
||||||
|
FilterID string `json:"filter_id,omitempty"`
|
||||||
|
|
||||||
|
// Integrated is set true for integrated GPUs, false for Discrete GPUs
|
||||||
|
Integrated bool `json:"integration,omitempty"`
|
||||||
|
|
||||||
|
// PCIID is the bus, device and domain ID of the device for deduplication
|
||||||
|
// when discovered by multiple backends
|
||||||
|
PCIID string `json:"pci_id,omitempty"`
|
||||||
|
|
||||||
|
// TotalMemory is the total amount of memory the device can use for loading models
|
||||||
|
TotalMemory uint64 `json:"total_memory"`
|
||||||
|
|
||||||
|
// FreeMemory is the amount of memory currently available on the device for loading models
|
||||||
|
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||||
|
|
||||||
|
// ComputeMajor is the major version of capabilities of the device
|
||||||
|
// if unsupported by the backend, -1 will be returned
|
||||||
|
ComputeMajor int
|
||||||
|
|
||||||
|
// ComputeMinor is the minor version of capabilities of the device
|
||||||
|
// if unsupported by the backend, -1 will be returned
|
||||||
|
ComputeMinor int
|
||||||
|
|
||||||
|
// Driver Information
|
||||||
|
DriverMajor int `json:"driver_major,omitempty"`
|
||||||
|
DriverMinor int `json:"driver_minor,omitempty"`
|
||||||
|
|
||||||
|
// Where backends were loaded from
|
||||||
|
LibraryPath []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SystemInfo struct {
|
||||||
|
// ThreadCount is the optimal number of threads to use for inference
|
||||||
|
ThreadCount int `json:"threads,omitempty"`
|
||||||
|
|
||||||
|
// TotalMemory is the total amount of system memory
|
||||||
|
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||||
|
|
||||||
|
// FreeMemory is the amount of memory currently available on the system for loading models
|
||||||
|
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||||
|
|
||||||
|
// FreeSwap is the amount of system swap space reported as available
|
||||||
|
FreeSwap uint64 `json:"free_swap,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DeviceInfo) Compute() string {
|
||||||
|
// AMD gfx is encoded into the major minor in hex form
|
||||||
|
if strings.EqualFold(d.Library, "ROCm") {
|
||||||
|
return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
|
||||||
|
}
|
||||||
|
return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
|
||||||
|
}
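A worked example of the ROCm encoding above; the values are chosen purely for illustration and are not taken from this diff.

```go
package main

import "fmt"

func main() {
    // A ROCm device reporting ComputeMajor=9 and ComputeMinor=0x0a renders as an
    // AMD gfx target name using the same format string as Compute().
    fmt.Printf("gfx%x%02x\n", 9, 0x0a) // prints "gfx90a"
}
```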
|
||||||
|
|
||||||
|
func (d DeviceInfo) Driver() string {
|
||||||
|
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinimumMemory reports the amount of memory that should be set aside
|
||||||
|
// on the device for overhead (e.g. VRAM consumed by context structures independent
|
||||||
|
// of model allocations)
|
||||||
|
func (d DeviceInfo) MinimumMemory() uint64 {
|
||||||
|
if d.Library == "Metal" {
|
||||||
|
return 512 * format.MebiByte
|
||||||
|
}
|
||||||
|
return 457 * format.MebiByte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by Free Space.
|
||||||
|
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
|
||||||
|
type ByFreeMemory []DeviceInfo
|
||||||
|
|
||||||
|
func (a ByFreeMemory) Len() int      { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
    if a[i].Integrated && !a[j].Integrated {
        return true
    } else if !a[i].Integrated && a[j].Integrated {
        return false
    }
    return a[i].FreeMemory < a[j].FreeMemory
}

// ByPerformance groups devices by similar speed
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
    resp := [][]DeviceInfo{}
    scores := []bool{}
    for _, info := range l {
        found := false
        requested := info.Integrated
        for i, score := range scores {
            if score == requested {
                resp[i] = append(resp[i], info)
                found = true
                break
            }
        }
        if !found {
            scores = append(scores, requested)
            resp = append(resp, []DeviceInfo{info})
        }
    }
    return resp
}

func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
    resp := [][]DeviceInfo{}
    libs := []string{}
    for _, info := range l {
        found := false
        requested := info.Library
        for i, lib := range libs {
            if lib == requested {
                resp[i] = append(resp[i], info)
                found = true
                break
            }
        }
        if !found {
            libs = append(libs, requested)
            resp = append(resp, []DeviceInfo{info})
        }
    }
    return resp
}

func LibraryPaths(l []DeviceInfo) []string {
    gpuLibs := []string{LibOllamaPath}
    for _, gpu := range l {
        for _, dir := range gpu.LibraryPath {
            needed := true
            for _, existing := range gpuLibs {
                if dir == existing {
                    needed = false
                    break
                }
            }
            if needed {
                gpuLibs = append(gpuLibs, dir)
            }
        }
    }
    return gpuLibs
}

type DeviceComparison int

const (
    UniqueDevice      DeviceComparison = iota
    SameBackendDevice // The device is the same, and the library/backend is the same
    DuplicateDevice   // The same physical device but different library/backend (overlapping device)
)

func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
    if a.PCIID != b.PCIID {
        return UniqueDevice
    }
    // If PCIID is empty, we have to use ID + library for uniqueness
    if a.PCIID == "" && a.DeviceID != b.DeviceID {
        return UniqueDevice
    }
    if a.Library == b.Library {
        return SameBackendDevice
    }
    return DuplicateDevice
}

// For a SameBackendDevice, return true if b is better than a,
// e.g. a newer GPU library version
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
    aLib := a.LibraryPath[len(a.LibraryPath)-1]
    bLib := b.LibraryPath[len(b.LibraryPath)-1]
    if aLib == bLib {
        return false
    }
    aLibSplit := strings.SplitN(aLib, "_", 2)
    bLibSplit := strings.SplitN(bLib, "_", 2)
    if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
        return false
    }
    if aLibSplit[0] != bLibSplit[0] {
        slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
        return false
    }
    if aLibSplit[1] == bLibSplit[1] {
        return false
    }
    cmp := []string{aLibSplit[1], bLibSplit[1]}
    sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
    return cmp[0] == bLibSplit[1]
}

// FlashAttentionSupported returns true only if every GPU in the list supports flash attention
func FlashAttentionSupported(l []DeviceInfo) bool {
    for _, gpu := range l {
        supportsFA := gpu.Library == "cpu" ||
            gpu.Name == "Metal" || gpu.Library == "Metal" ||
            (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
            gpu.Library == "ROCm" ||
            gpu.Library == "Vulkan"

        if !supportsFA {
            return false
        }
    }
    return true
}

// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables.
// Set mustFilter true to enable filtering of CUDA devices.
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
    if len(l) == 0 {
        return nil
    }
    env := map[string]string{}
    for _, d := range l {
        d.updateVisibleDevicesEnv(env, mustFilter)
    }
    return env
}

// NeedsInitValidation returns true if the device in question has the potential
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
    // ROCm: rocblas will crash on unsupported devices.
    // CUDA: verify CC is supported by the version of the library
    return d.Library == "ROCm" || d.Library == "CUDA"
}

// Set the init validation environment variable
func (d DeviceInfo) AddInitValidation(env map[string]string) {
    env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
}

// PreferredLibrary returns true if this library is preferred over the other input
// library.
// Used to filter out Vulkan in favor of CUDA or ROCm.
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
    // TODO in the future if we find Vulkan is better than ROCm on some devices
    // that implementation can live here.

    if d.Library == "CUDA" || d.Library == "ROCm" {
        return true
    }
    return false
}

func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
    var envVar string
    switch d.Library {
    case "ROCm":
        // ROCm must be filtered as it can crash the runner on unsupported devices
        envVar = "ROCR_VISIBLE_DEVICES"
        if runtime.GOOS != "linux" {
            envVar = "HIP_VISIBLE_DEVICES"
        }
    case "CUDA":
        if !mustFilter {
            // By default we try to avoid filtering CUDA devices because ROCm also
            // looks at the CUDA env var, and gets confused in mixed vendor environments.
            return
        }
        envVar = "CUDA_VISIBLE_DEVICES"
    default:
        // Vulkan is not filtered via env var, but via scheduling decisions
        return
    }
    v, existing := env[envVar]
    if existing {
        v = v + ","
    }
    if d.FilterID != "" {
        v = v + d.FilterID
    } else {
        v = v + d.ID
    }
    env[envVar] = v
}

type BaseRunner interface {
    // GetPort returns the localhost port number the runner is running on
    GetPort() int

    // HasExited indicates if the runner is no longer running. This can be used during
    // bootstrap to detect if a given filtered device is incompatible and triggered an assert
    HasExited() bool
}

type RunnerDiscovery interface {
    BaseRunner

    // GetDeviceInfos will perform a query of the underlying device libraries
    // for device identification and free VRAM information.
    // During bootstrap scenarios, this routine may take seconds to complete.
    GetDeviceInfos(ctx context.Context) []DeviceInfo
}

type FilteredRunnerDiscovery interface {
    RunnerDiscovery

    // GetActiveDeviceIDs returns the filtered set of devices actively in
    // use by this runner for running models. If the runner is a bootstrap runner, no devices
    // will be active yet so no device IDs are returned.
    // This routine will not query the underlying device and will return immediately.
    GetActiveDeviceIDs() []DeviceID
}

func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
    var moreDevices []DeviceInfo
    port := runner.GetPort()
    tick := time.Tick(10 * time.Millisecond)
    for {
        select {
        case <-ctx.Done():
            return nil, fmt.Errorf("failed to finish discovery before timeout")
        case <-tick:
            r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
            if err != nil {
                return nil, fmt.Errorf("failed to create request: %w", err)
            }
            r.Header.Set("Content-Type", "application/json")

            resp, err := http.DefaultClient.Do(r)
            if err != nil {
                // slog.Warn("failed to send request", "error", err)
                if runner.HasExited() {
                    return nil, fmt.Errorf("runner crashed")
                }
                continue
            }
            defer resp.Body.Close()

            if resp.StatusCode == http.StatusNotFound {
                // old runner, fall back to bootstrapping model
                return nil, fmt.Errorf("llamarunner free vram reporting not supported")
            }

            body, err := io.ReadAll(resp.Body)
            if err != nil {
                slog.Warn("failed to read response", "error", err)
                continue
            }
            if resp.StatusCode != 200 {
                logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
                return nil, fmt.Errorf("runner error: %s", string(body))
            }

            if err := json.Unmarshal(body, &moreDevices); err != nil {
                slog.Warn("unmarshal encode response", "error", err)
                continue
            }
            return moreDevices, nil
        }
    }
}
103
x/ml/nn/attention.go
Normal file
@@ -0,0 +1,103 @@
package nn

import (
    "fmt"

    "github.com/ollama/ollama/x/kvcache"
    "github.com/ollama/ollama/x/ml"
)

// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
    return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}

func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
    return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}

func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
    ctx.Forward(query)

    if key != nil && value != nil {
        if query.Dim(0) != key.Dim(0) {
            panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
        }

        if key.Dim(1) != value.Dim(1) {
            panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
        }

        if key.Dim(2) != value.Dim(2) {
            panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
        }

        ctx.Forward(key, value)
        if cache != nil {
            cache.Put(ctx, key, value)
        }
    } else if cache == nil {
        panic("key & value tensors must be provided if cache is nil")
    }

    // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
    // panic("after cache get") //
    // 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
    // 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
    // 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]

    // var mask ml.Tensor
    if cache != nil {
        key, value, _ = cache.Get(ctx)
    }
    // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
    // panic("after cache get") //
    // 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
    // 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
    // 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]

    // Only use the fast SDPA implementation if we have a cache, since that's what
    // will do any expected backend-specific transformations for us

    if cache != nil {
        // TODO what to do with vmla?
        // return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
        return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)

        // TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
    } else {
        panic("else case not supported")
        // TODO transpose shapes are wrong
        // key = key.Transpose(ctx, 0, 2, 1, 3)
        // value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)

        // kq := query.Matmul(ctx, key)

        // kq = kq.Scale(ctx, scale)
        // if mask != nil {
        // 	kq = kq.Add(ctx, mask)
        // }
        // kq = kq.Softmax(ctx)

        // kqv := kq.Matmul(ctx, value)

        // if vmla != nil {
        // 	kqv = kqv.Matmul(ctx, vmla)
        // }

        // return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
    }
}
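For context, here is a minimal sketch of how a model layer might call `nn.Attention` from this file. The `selfAttention` struct, its field names and gguf tags, and the `headDim` parameter are assumptions for illustration only; they are not part of this change.

```go
package sketch

import (
    "math"

    "github.com/ollama/ollama/x/kvcache"
    "github.com/ollama/ollama/x/ml"
    "github.com/ollama/ollama/x/ml/nn"
)

// selfAttention is an illustrative layer, not one defined in this diff.
type selfAttention struct {
    Query *nn.Linear `gguf:"attn_q"`
    Key   *nn.Linear `gguf:"attn_k"`
    Value *nn.Linear `gguf:"attn_v"`
}

func (sa *selfAttention) forward(ctx ml.Context, hidden ml.Tensor, cache kvcache.Cache, headDim int) ml.Tensor {
    q := sa.Query.Forward(ctx, hidden)
    k := sa.Key.Forward(ctx, hidden)
    v := sa.Value.Forward(ctx, hidden)

    // Reshaping q/k/v into the [d_k, heads, seq_len] layout described in the
    // Attention doc comment is elided here; scale is the usual 1/sqrt(d_k).
    return nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
}
```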
30
x/ml/nn/convolution.go
Normal file
@@ -0,0 +1,30 @@
package nn

import "github.com/ollama/ollama/x/ml"

type Conv2D struct {
    Weight ml.Tensor `gguf:"weight"`
    Bias   ml.Tensor `gguf:"bias"`
}

func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
    t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
    if m.Bias != nil {
        // Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
        t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
    }
    return t
}

type Conv3D struct {
    Weight ml.Tensor `gguf:"weight"`
    Bias   ml.Tensor `gguf:"bias"`
}

func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
    t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
    if m.Bias != nil {
        t = t.Add(ctx, m.Bias)
    }
    return t
}
11
x/ml/nn/embedding.go
Normal file
@@ -0,0 +1,11 @@
package nn

import "github.com/ollama/ollama/x/ml"

type Embedding struct {
    Weight ml.Tensor `gguf:"weight"`
}

func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
    return m.Weight.TakeAxes(ctx, hiddenState, 0)
}
32
x/ml/nn/linear.go
Normal file
@@ -0,0 +1,32 @@
package nn

import "github.com/ollama/ollama/x/ml"

type Linear struct {
    Weight ml.Tensor `gguf:"weight"`
    Bias   ml.Tensor `gguf:"bias"`
}

func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
    t = t.Matmul(ctx, m.Weight.Transpose(ctx))
    if m.Bias != nil {
        t = t.Add(ctx, m.Bias)
    }

    return t
}

type LinearBatch struct {
    Weight ml.Tensor `gguf:"weight"`
    Bias   ml.Tensor `gguf:"bias"`
}

func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
    panic("not yet ported")
    // t = m.Weight.MulmatID(ctx, t, indices)
    // if m.Bias != nil {
    // 	t = t.AddID(ctx, m.Bias, indices)
    // }

    // return t
}
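As a small sketch of how these projections are typically chained, the helper below applies several `nn.Linear` layers in sequence (each `Forward` computes the input times the transposed weight, plus the bias when one is present). The helper name and its use are illustrative only.

```go
package sketch

import (
    "github.com/ollama/ollama/x/ml"
    "github.com/ollama/ollama/x/ml/nn"
)

// denseStack applies a sequence of nn.Linear projections to a hidden state.
func denseStack(ctx ml.Context, hidden ml.Tensor, layers []*nn.Linear) ml.Tensor {
    for _, l := range layers {
        hidden = l.Forward(ctx, hidden)
    }
    return hidden
}
```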
29
x/ml/nn/normalization.go
Normal file
@@ -0,0 +1,29 @@
package nn

import (
    "github.com/ollama/ollama/x/ml"
)

type LayerNorm struct {
    Weight ml.Tensor `gguf:"weight"`
    Bias   ml.Tensor `gguf:"bias"`
}

func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
    return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}

type RMSNorm struct {
    Weight ml.Tensor `gguf:"weight"`
}

func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
    // slog.Info("RMSNorm", "eps", eps)
    // fmt.Fprintln(os.Stderr, t.ToString())
    // fmt.Fprintln(os.Stderr, m.Weight.ToString())

    // TODO this is probably model specific, not generalized...
    w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))

    return t.RMSNorm(ctx, w, eps)
}
41
x/ml/nn/pooling/pooling.go
Normal file
@@ -0,0 +1,41 @@
package pooling

import (
    "github.com/ollama/ollama/x/ml"
)

type Type uint32

const (
    TypeNone Type = iota
    TypeMean
    TypeCLS
    TypeLast
)

func (t Type) String() string {
    switch t {
    case TypeMean:
        return "Mean"
    case TypeCLS:
        return "CLS"
    case TypeLast:
        return "Last"
    default:
        return "Unknown"
    }
}

func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
    switch t {
    // case TypeMean:
    // 	hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
    // 	return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
    // case TypeCLS:
    // 	return hiddenStates.Slice(ctx, 1, 0, 1, 1)
    // case TypeLast:
    // 	return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
    default:
        panic("unknown pooling type")
    }
}
72
x/ml/nn/rope/rope.go
Normal file
@@ -0,0 +1,72 @@
package rope

import "github.com/ollama/ollama/x/ml"

// Options contains optional parameters for RoPE function
type Options struct {
    Type    int
    Factors ml.Tensor

    // YaRN options
    YaRN struct {
        OriginalContextLength int
        ExtrapolationFactor,
        AttentionFactor,
        BetaFast,
        BetaSlow float32
    }

    // MRoPE options
    MRoPE struct {
        Sections []int
    }
}

// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
    return func(opts *Options) {
        opts.Type = 2
    }
}

// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
    return func(opts *Options) {
        if factors != nil {
            opts.Factors = factors
        }
    }
}

// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
    return func(opts *Options) {
        opts.YaRN.OriginalContextLength = n
    }
}

func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
    return func(opts *Options) {
        opts.YaRN.ExtrapolationFactor = extrapolationFactor
    }
}

func WithAttentionFactor(attentionFactor float32) func(*Options) {
    return func(opts *Options) {
        opts.YaRN.AttentionFactor = attentionFactor
    }
}

func WithMRoPE(sections []int) func(*Options) {
    return func(opts *Options) {
        opts.Type |= 1 << 3
        opts.MRoPE.Sections = sections
    }
}

func WithInterleaveMRoPE(sections []int) func(*Options) {
    return func(opts *Options) {
        opts.Type |= 1<<3 | 1<<5
        opts.MRoPE.Sections = sections
    }
}
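These `With*` helpers follow the functional-options pattern. A minimal sketch of composing them is shown below; the 8192 context length is an arbitrary example value, and the RoPE kernel that would consume the resulting `Options` is not part of this change.

```go
package sketch

import "github.com/ollama/ollama/x/ml/nn/rope"

// buildOptions folds a few functional options into a single rope.Options value.
func buildOptions() rope.Options {
    var opts rope.Options
    for _, fn := range []func(*rope.Options){
        rope.WithTypeNeoX(),
        rope.WithOriginalContextLength(8192),
        rope.WithExtrapolationFactor(1.0),
    } {
        fn(&opts)
    }
    return opts
}
```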
56
x/ml/path.go
Normal file
@@ -0,0 +1,56 @@
package ml

import (
    "os"
    "path/filepath"
    "runtime"
)

// LibOllamaPath is a path to lookup dynamic libraries.
// In development it's usually 'build/lib/ollama'.
// In distribution builds it's 'lib/ollama' on Windows,
// '../lib/ollama' on Linux, and the executable's directory on macOS.
// Note: in distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
    exe, err := os.Executable()
    if err != nil {
        return ""
    }

    if eval, err := filepath.EvalSymlinks(exe); err == nil {
        exe = eval
    }

    var libPath string
    switch runtime.GOOS {
    case "windows":
        libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
    case "linux":
        libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
    case "darwin":
        libPath = filepath.Dir(exe)
    }

    cwd, err := os.Getwd()
    if err != nil {
        return ""
    }

    paths := []string{
        libPath,

        // build paths for development
        filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
        filepath.Join(cwd, "build", "lib", "ollama"),
    }

    for _, p := range paths {
        if _, err := os.Stat(p); err == nil {
            return p
        }
    }

    return filepath.Dir(exe)
}()
282
x/model/bytepairencoding.go
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"log/slog"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/dlclark/regexp2"
|
||||||
|
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BytePairEncoding struct {
|
||||||
|
vocab *Vocabulary
|
||||||
|
regexps []*regexp2.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||||
|
|
||||||
|
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||||
|
if len(pretokenizers) == 0 {
|
||||||
|
// set default byte-level pretokenizer if none provided, e.g.
|
||||||
|
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||||
|
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||||
|
}
|
||||||
|
|
||||||
|
return BytePairEncoding{
|
||||||
|
vocab: vocab,
|
||||||
|
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||||
|
for _, p := range pretokenizers {
|
||||||
|
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||||
|
return bpe.vocab
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||||
|
return bpe.vocab.Is(id, special)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||||
|
parts := []string{s}
|
||||||
|
for _, re := range bpe.regexps {
|
||||||
|
parts = slices.Collect(func(yield func(string) bool) {
|
||||||
|
for _, part := range parts {
|
||||||
|
r := []rune(part)
|
||||||
|
var offset int
|
||||||
|
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||||
|
if offset-m.Index != 0 {
|
||||||
|
if !yield(string(r[:m.Index])) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !yield(m.String()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offset = m.Index + m.Length
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset < len(r) {
|
||||||
|
if !yield(string(r[offset:])) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.Values(parts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fragment is a string fragment and its corresponding token IDs
|
||||||
|
type fragment struct {
|
||||||
|
value string
|
||||||
|
ids []int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// pair is a pair of runes and its rank
|
||||||
|
type pair struct {
|
||||||
|
a, b int
|
||||||
|
rank int
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type merge struct {
|
||||||
|
p, n int
|
||||||
|
runes []rune
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
|
fragments := []fragment{{value: s}}
|
||||||
|
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||||
|
// TODO: process special tokens concurrently
|
||||||
|
id := bpe.vocab.Encode(special)
|
||||||
|
for i := 0; i < len(fragments); i++ {
|
||||||
|
frag := fragments[i]
|
||||||
|
if len(frag.ids) > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var middle []fragment
|
||||||
|
switch i := strings.Index(frag.value, special); {
|
||||||
|
case i < 0:
|
||||||
|
middle = append(middle, frag)
|
||||||
|
case i > 0:
|
||||||
|
middle = append(middle, fragment{value: frag.value[:i]})
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
||||||
|
if rest := frag.value[i+len(special):]; rest != "" {
|
||||||
|
middle = append(middle, fragment{value: rest})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var ids []int32
|
||||||
|
for _, frag := range fragments {
|
||||||
|
if len(frag.ids) > 0 {
|
||||||
|
ids = append(ids, frag.ids...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for split := range bpe.split(frag.value) {
|
||||||
|
// TODO: process splits concurrently
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, b := range []byte(split) {
|
||||||
|
r := rune(b)
|
||||||
|
switch {
|
||||||
|
case r == 0x00ad:
|
||||||
|
r = 0x0143
|
||||||
|
case r <= 0x0020:
|
||||||
|
r = r + 0x0100
|
||||||
|
case r >= 0x007f && r <= 0x00a0:
|
||||||
|
r = r + 0x00a2
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteRune(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// short circuit if the fragment is in the vocabulary
|
||||||
|
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||||
|
ids = append(ids, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
runes := []rune(sb.String())
|
||||||
|
merges := make([]merge, len(runes))
|
||||||
|
for r := range runes {
|
||||||
|
merges[r] = merge{
|
||||||
|
p: r - 1,
|
||||||
|
n: r + 1,
|
||||||
|
runes: []rune{runes[r]},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pairwise := func(a, b int) *pair {
|
||||||
|
if a < 0 || b >= len(runes) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||||
|
rank := bpe.vocab.Merge(left, right)
|
||||||
|
if rank < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pair{
|
||||||
|
a: a,
|
||||||
|
b: b,
|
||||||
|
rank: rank,
|
||||||
|
value: left + right,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pairs := heap.NewWith(func(i, j *pair) int {
|
||||||
|
return cmp.Compare(i.rank, j.rank)
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range len(runes) - 1 {
|
||||||
|
if pair := pairwise(i, i+1); pair != nil {
|
||||||
|
pairs.Push(pair)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for !pairs.Empty() {
|
||||||
|
pair, _ := pairs.Pop()
|
||||||
|
|
||||||
|
left, right := merges[pair.a], merges[pair.b]
|
||||||
|
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
||||||
|
string(left.runes)+string(right.runes) != pair.value {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||||
|
merges[pair.b].runes = nil
|
||||||
|
|
||||||
|
merges[pair.a].n = right.n
|
||||||
|
if right.n < len(merges) {
|
||||||
|
merges[right.n].p = pair.a
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||||
|
pairs.Push(pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||||
|
pairs.Push(pair)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, merge := range merges {
|
||||||
|
if len(merge.runes) > 0 {
|
||||||
|
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||||
|
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if addSpecial {
|
||||||
|
ids = bpe.vocab.addSpecials(ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type lazyIdsString struct {
|
||||||
|
ids []int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l lazyIdsString) LogValue() slog.Value {
|
||||||
|
return slog.AnyValue(fmt.Sprint(l.ids))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, id := range ids {
|
||||||
|
for _, r := range bpe.vocab.Decode(id) {
|
||||||
|
switch {
|
||||||
|
case r == 0x0100:
|
||||||
|
// this produces 0x00 aka NULL
|
||||||
|
continue
|
||||||
|
case r == 0x0143:
|
||||||
|
r = 0x00ad
|
||||||
|
case r > 0x0100 && r <= 0x0120:
|
||||||
|
r = r - 0x0100
|
||||||
|
case r > 0x0120 && r <= 0x0142:
|
||||||
|
r = r - 0x00a2
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||||
|
// encoding of the rune which is _not_ what we want
|
||||||
|
if err := sb.WriteByte(byte(r)); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
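A minimal sketch of the Encode/Decode contract exposed by `BytePairEncoding` above: the vocabulary is assumed to be populated from GGUF metadata elsewhere, and the helper name here is illustrative only.

```go
package sketch

import (
    "fmt"

    "github.com/ollama/ollama/x/model"
)

// roundTrip encodes a string to token IDs and decodes it back.
func roundTrip(vocab *model.Vocabulary) error {
    bpe := model.NewBytePairEncoding(vocab) // falls back to the default byte-level pretokenizer
    ids, err := bpe.Encode("hello world", true)
    if err != nil {
        return err
    }
    s, err := bpe.Decode(ids)
    if err != nil {
        return err
    }
    fmt.Println(ids, s)
    return nil
}
```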
322
x/model/bytepairencoding_test.go
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func llama(t testing.TB) BytePairEncoding {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
vocab := make(map[string]int32)
|
||||||
|
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
types := make([]int32, len(vocab))
|
||||||
|
tokens := make([]string, len(vocab))
|
||||||
|
for token, id := range vocab {
|
||||||
|
tokens[id] = token
|
||||||
|
types[id] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||||
|
if _, ok := vocab[token]; !ok {
|
||||||
|
tokens = append(tokens, token) //nolint:makezero
|
||||||
|
types = append(types, 3) //nolint:makezero
|
||||||
|
vocab[token] = int32(len(vocab))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
merges := make([]string, 0, 50000)
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
if !strings.HasPrefix(scanner.Text(), "#") {
|
||||||
|
merges = append(merges, scanner.Text())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewBytePairEncoding(
|
||||||
|
&Vocabulary{
|
||||||
|
Values: tokens,
|
||||||
|
Types: types,
|
||||||
|
Merges: merges,
|
||||||
|
},
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLlama(t *testing.T) {
|
||||||
|
tokenizer := llama(t)
|
||||||
|
|
||||||
|
t.Run("simple", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ids, err := tokenizer.Encode("hello world", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
|
||||||
|
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := tokenizer.Decode([]int32{15339, 1917})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s != "hello world" {
|
||||||
|
t.Errorf("got %q, want hello world", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
|
||||||
|
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("simple repeated", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := map[string][]int32{
|
||||||
|
strings.Repeat("0", 1): {15},
|
||||||
|
strings.Repeat("0", 2): {410},
|
||||||
|
strings.Repeat("0", 3): {931},
|
||||||
|
strings.Repeat("0", 4): {931, 15},
|
||||||
|
strings.Repeat("0", 5): {931, 410},
|
||||||
|
strings.Repeat("0", 6): {931, 931},
|
||||||
|
strings.Repeat("0", 7): {931, 931, 15},
|
||||||
|
strings.Repeat("0", 8): {931, 931, 410},
|
||||||
|
strings.Repeat("0", 9): {931, 931, 931},
|
||||||
|
strings.Repeat("0", 10): {931, 931, 931, 15},
|
||||||
|
strings.Repeat("0", 11): {931, 931, 931, 410},
|
||||||
|
strings.Repeat("0", 12): {931, 931, 931, 931},
|
||||||
|
strings.Repeat("0", 13): {931, 931, 931, 931, 15},
|
||||||
|
strings.Repeat("0", 14): {931, 931, 931, 931, 410},
|
||||||
|
strings.Repeat("0", 15): {931, 931, 931, 931, 931},
|
||||||
|
strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
|
||||||
|
strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
|
||||||
|
}
|
||||||
|
|
||||||
|
for s, want := range cases {
|
||||||
|
ids, err := tokenizer.Encode(s, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, ids); diff != "" {
|
||||||
|
t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("basic roundtrip", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := []string{
|
||||||
|
"hello",
|
||||||
|
"hello ",
|
||||||
|
"hello ",
|
||||||
|
" hello",
|
||||||
|
" hello ",
|
||||||
|
" hello ",
|
||||||
|
"hello world",
|
||||||
|
"请考试我的软件!12345",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, want := range cases {
|
||||||
|
ids, err := tokenizer.Encode(want, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, err := tokenizer.Decode(ids); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if got != want {
|
||||||
|
t.Errorf("got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("special", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := map[string][]int32{
|
||||||
|
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||||
|
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||||
|
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||||
|
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||||
|
}
|
||||||
|
|
||||||
|
for s, want := range cases {
|
||||||
|
ids, err := tokenizer.Encode(s, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, ids); diff != "" {
|
||||||
|
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("split", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := map[string][]string{
|
||||||
|
"Hello World!": {"Hello", " World", "!"},
|
||||||
|
"I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
|
||||||
|
"In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
|
||||||
|
"Hello!! ...world": {"Hello", "!!", " ...", "world"},
|
||||||
|
"Hello World": {"Hello", " ", " World"},
|
||||||
|
"Hello\nWorld": {"Hello", "\n", "World"},
|
||||||
|
"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for s, want := range cases {
|
||||||
|
got := slices.Collect(tokenizer.split(s))
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for b := 0x00; b <= 0xFF; b++ {
|
||||||
|
input := string(rune(b))
|
||||||
|
ids, err := tokenizer.Encode(input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := tokenizer.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if b == 0x00 {
|
||||||
|
if len(decoded) != 0 {
|
||||||
|
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded != input {
|
||||||
|
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||||
|
tokenizer := llama(b)
|
||||||
|
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range 8 {
|
||||||
|
n := min(int(math.Pow10(i)), len(bts))
|
||||||
|
bts := bts[:n]
|
||||||
|
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, err := tokenizer.Encode(string(bts), true)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||||
|
ids, err := tokenizer.Encode(string(bts), true)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, err := tokenizer.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
slices.Collect(tokenizer.split(string(bts)))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplit(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
patterns,
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode",
|
||||||
|
patterns: []string{
|
||||||
|
"\\p{N}{1,3}",
|
||||||
|
`[一-龥-ゟ゠-ヿ]+`,
|
||||||
|
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
},
|
||||||
|
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "individual digits",
|
||||||
|
patterns: []string{
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
},
|
||||||
|
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
|
||||||
|
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
|
||||||
|
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
76
x/model/input/input.go
Normal file
@@ -0,0 +1,76 @@
package input

import "github.com/ollama/ollama/x/ml"

// Multimodal is a multimodal embedding or a component of one.
// For example, it could be a row of an image that can be processed
// independently.
type Multimodal struct {
    // Tensor is the embedding data. Implementations may choose what to
    // store here or it may be nil if not needed. However, any ml.Tensor
    // objects must be stored here and not in Data.
    Tensor ml.Tensor

    // Data is implementation-specific opaque data, such as metadata on how
    // to layout Tensor. It may be nil if not needed. It may also store larger
    // objects such as complete images if they are to be processed later.
    Data any
}

// Input represents one token in the input stream
type Input struct {
    // Token is a single element of text.
    Token int32

    // Multimodal represents a non-text element such as an
    // image (or part of one if the image can be processed in pieces).
    // It may be used either together with Token or on its own.
    Multimodal []Multimodal

    // MultimodalHash is a unique representation of the data
    // stored in Multimodal, used for caching and comparing
    // equality.
    MultimodalHash uint64

    // SameBatch forces the following number of tokens to be processed
    // in a single batch, breaking and extending batches as needed.
    // Useful for things like images that must be processed in one
    // shot.
    SameBatch int
}

// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
    Index      int
    Multimodal []Multimodal
}

// Batch contains the inputs for a model forward pass
type Batch struct {
    // Inputs is the input tokens, including placeholders for multimodal inputs.
    Inputs ml.Tensor

    // Outputs are the set of indices into Inputs for which output data should
    // be returned.
    Outputs ml.Tensor

    // TODO maybe not the optimal way to handle this
    // Offset of final tensor in the final batch
    Offset int

    // Positions is the position for each Input, relative to its sequence. Equal
    // in length to Inputs.
    Positions []int32

    // Sequences is the sequence for each Input. Equal in length to Inputs.
    Sequences []int

    // Multimodal is a set of multimodal embeddings previously created by
    // EncodeMultimodal, along with an index into Inputs. Unused for text-only
    // models or for batches without multimodal elements.
    Multimodal []MultimodalIndex
}
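A short sketch of how an `Input` stream might be assembled from tokens plus a multimodal element. The helper name and its use of `SameBatch` are illustrative assumptions; in practice each model arranges this itself in `PostTokenize`.

```go
package sketch

import "github.com/ollama/ollama/x/model/input"

// tokensToInputs builds one Input per text token and attaches any multimodal
// chunks to a trailing Input.
func tokensToInputs(tokens []int32, image []input.Multimodal) []*input.Input {
    inputs := make([]*input.Input, 0, len(tokens)+1)
    for _, t := range tokens {
        inputs = append(inputs, &input.Input{Token: t})
    }
    if len(image) > 0 {
        // SameBatch keeps the following tokens together so the image is
        // processed in one shot (illustrative usage).
        inputs = append(inputs, &input.Input{Multimodal: image, SameBatch: len(image)})
    }
    return inputs
}
```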
333
x/model/model.go
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
_ "image/jpeg"
|
||||||
|
_ "image/png"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
_ "golang.org/x/image/bmp"
|
||||||
|
_ "golang.org/x/image/tiff"
|
||||||
|
_ "golang.org/x/image/webp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
"github.com/ollama/ollama/x/kvcache"
|
||||||
|
"github.com/ollama/ollama/x/ml"
|
||||||
|
_ "github.com/ollama/ollama/x/ml/backend"
|
||||||
|
"github.com/ollama/ollama/x/ml/nn/pooling"
|
||||||
|
"github.com/ollama/ollama/x/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||||
|
ErrUnsupportedModel = errors.New("model not supported")
|
||||||
|
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||||
|
type Model interface {
|
||||||
|
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||||
|
|
||||||
|
Backend() ml.Backend
|
||||||
|
Config() config
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultimodalProcessor must be implemented by multimodal models.
|
||||||
|
type MultimodalProcessor interface {
|
||||||
|
// EncodeMultimodal processes a single input (such as an image) and
|
||||||
|
// generates an output (typically an embedding) that can be used by the model.
|
||||||
|
//
|
||||||
|
// The return value is one or more tensors, each with optional model-specific
|
||||||
|
// opaque metadata. Typically, the tensors might be views into an embedding
|
||||||
|
// with each view representing a chunk of data that can be processed independently
|
||||||
|
// in different batches.
|
||||||
|
//
|
||||||
|
// The result may be cached by the runner.
|
||||||
|
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
|
||||||
|
|
||||||
|
// PostTokenize is called after tokenization to allow the model to edit the
|
||||||
|
// input stream to correctly arrange multimodal elements.
|
||||||
|
//
|
||||||
|
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
|
||||||
|
// in the order that the user provided them. Each element of the slice will be
|
||||||
|
// either a single token or single multimodal object.
|
||||||
|
//
|
||||||
|
// The model must ensure that inputs are stored according to how they will be
|
||||||
|
// processed and stored in the cache. For example, Llava-style models should insert
|
||||||
|
// placeholder tokens equal to the feature size of the corresponding image with
|
||||||
|
// the image itself attached to and split across these tokens. When Forward is called
|
||||||
|
// a partial subset of these tokens may be submitted according to the batch size.
|
||||||
|
//
|
||||||
|
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||||
|
// that is modified to ensure that there is a unique hash value that accurately
|
||||||
|
// represents the contents.
|
||||||
|
PostTokenize([]*input.Input) ([]*input.Input, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base implements the common fields and methods for all models
|
||||||
|
type Base struct {
|
||||||
|
b ml.Backend
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
Cache kvcache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend returns the underlying backend that will run the model
|
||||||
|
func (m *Base) Backend() ml.Backend {
|
||||||
|
return m.b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Base) Config() config {
|
||||||
|
return m.config
|
||||||
|
}
|
||||||
|
|
||||||
|
var models = make(map[string]func(fs.Config) (Model, error))
|
||||||
|
|
||||||
|
// Register registers a model constructor for the given architecture
|
||||||
|
func Register(name string, f func(fs.Config) (Model, error)) {
|
||||||
|
if _, ok := models[name]; ok {
|
||||||
|
panic("model: model already registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
models[name] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||||
|
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||||
|
b, err := ml.NewBackend(modelPath, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := modelForArch(b.Config())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
base := Base{b: b, config: m.Config()}
|
||||||
|
v := reflect.ValueOf(m)
|
||||||
|
v.Elem().Set(populateFields(base, v.Elem()))
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||||
|
r, err := os.Open(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
meta, err := fsggml.Decode(r, -1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := modelForArch(meta.KV())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tp, ok := m.(TextProcessor)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrUnsupportedTokenizer
|
||||||
|
}
|
||||||
|
return tp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func modelForArch(c fs.Config) (Model, error) {
|
||||||
|
arch := c.Architecture()
|
||||||
|
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||||
|
arch = arch + "_embed"
|
||||||
|
}
|
||||||
|
|
||||||
|
f, ok := models[arch]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrUnsupportedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
return f(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||||
|
t := v.Type()
|
||||||
|
|
||||||
|
if t.Kind() == reflect.Struct {
|
||||||
|
allNil := true
|
||||||
|
for i := range t.NumField() {
|
||||||
|
tt := t.Field(i).Type
|
||||||
|
vv := v.Field(i)
|
||||||
|
if !vv.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// make a copy
|
||||||
|
tagsCopy := tags
|
||||||
|
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
||||||
|
tagsCopy = append(tagsCopy, parseTag(tag))
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||||
|
vv.Set(reflect.ValueOf(base))
|
||||||
|
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||||
|
var fn func([]Tag, string, string) [][]string
|
||||||
|
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
|
||||||
|
if len(tags) > 0 {
|
||||||
|
var names []string
|
||||||
|
if tags[0].name != "" {
|
||||||
|
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
|
||||||
|
names = append(names, prefix+n+suffix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
|
||||||
|
if len(names) == 0 {
|
||||||
|
// current tag has no name, use child names only
|
||||||
|
fullNames = append(fullNames, childNames...)
|
||||||
|
} else if len(childNames) == 0 {
|
||||||
|
// current tag has names but no children, create branches for each name
|
||||||
|
for _, name := range names {
|
||||||
|
fullNames = append(fullNames, []string{name})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// merge each name with each child
|
||||||
|
for _, name := range names {
|
||||||
|
for _, childName := range childNames {
|
||||||
|
fullNames = append(fullNames, append([]string{name}, childName...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullNames
|
||||||
|
}
|
||||||
|
|
||||||
|
names := fn(tagsCopy, "", "")
|
||||||
|
for _, name := range names {
|
||||||
|
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||||
|
logutil.Trace("found tensor", "", tensor)
|
||||||
|
vv.Set(reflect.ValueOf(tensor))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
||||||
|
setPointer(base, vv, tagsCopy)
|
||||||
|
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
||||||
|
for i := range vv.Len() {
|
||||||
|
vvv := vv.Index(i)
|
||||||
|
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||||
|
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
|
||||||
|
} else {
|
||||||
|
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !canNil(tt) || !vv.IsNil() {
|
||||||
|
allNil = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allNil {
|
||||||
|
return reflect.Zero(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||||
|
vv := v
|
||||||
|
if v.Kind() == reflect.Interface {
|
||||||
|
if v.IsNil() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vv = vv.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
vv = reflect.Indirect(vv)
|
||||||
|
if v.IsNil() {
|
||||||
|
vv = reflect.New(v.Type().Elem()).Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if f := populateFields(base, vv, tags...); f.CanAddr() {
|
||||||
|
v.Set(f.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tag struct {
|
||||||
|
name,
|
||||||
|
// prefix and suffix are applied to child tags
|
||||||
|
prefix,
|
||||||
|
suffix string
|
||||||
|
alternatives []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTag(s string) (tag Tag) {
|
||||||
|
parts := strings.Split(s, ",")
|
||||||
|
if len(parts) > 0 {
|
||||||
|
tag.name = parts[0]
|
||||||
|
|
||||||
|
for _, part := range parts[1:] {
|
||||||
|
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
|
||||||
|
// elevate alternative to primary if no primary given
|
||||||
|
tag.name = value
|
||||||
|
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
|
||||||
|
} else if ok {
|
||||||
|
tag.alternatives = append(tag.alternatives, value)
|
||||||
|
}
|
||||||
|
if value, ok := strings.CutPrefix(part, "pre:"); ok {
|
||||||
|
tag.prefix = value
|
||||||
|
}
|
||||||
|
if value, ok := strings.CutPrefix(part, "suf:"); ok {
|
||||||
|
tag.suffix = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func canNil(t reflect.Type) bool {
|
||||||
|
return t.Kind() == reflect.Chan ||
|
||||||
|
t.Kind() == reflect.Func ||
|
||||||
|
t.Kind() == reflect.Interface ||
|
||||||
|
t.Kind() == reflect.Map ||
|
||||||
|
t.Kind() == reflect.Pointer ||
|
||||||
|
t.Kind() == reflect.Slice
|
||||||
|
}
|
||||||
|
|
||||||
|
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
if len(batch.Positions) != len(batch.Sequences) {
|
||||||
|
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(batch.Positions) < 1 {
|
||||||
|
return nil, errors.New("batch size cannot be less than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := m.Config().Cache
|
||||||
|
if cache != nil {
|
||||||
|
err := cache.StartForward(ctx, batch, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := m.Forward(ctx, batch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(t)
|
||||||
|
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
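For reference, a sketch of the gguf struct tags that `populateFields` and `parseTag` above resolve: a primary tensor name, an `alt:` fallback, and slice fields whose elements have their index joined in (`blk.0`, `blk.1`, ...). The struct and tensor names below are examples only and are not guaranteed to exist in any particular model file.

```go
package sketch

import "github.com/ollama/ollama/x/ml/nn"

// exampleModel illustrates the tag forms handled by parseTag.
type exampleModel struct {
    TokenEmbedding *nn.Embedding  `gguf:"token_embd,alt:tok_embeddings"`
    Layers         []exampleLayer `gguf:"blk"`
    OutputNorm     *nn.RMSNorm    `gguf:"output_norm"`
}

type exampleLayer struct {
    AttnNorm *nn.RMSNorm `gguf:"attn_norm"` // resolved as blk.<i>.attn_norm
}
```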
58
x/model/models/gemma3/embed.go
Normal file
@@ -0,0 +1,58 @@
//go:build mlx

package gemma3

import (
    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/x/ml"
    "github.com/ollama/ollama/x/ml/nn"
    "github.com/ollama/ollama/x/ml/nn/pooling"
    "github.com/ollama/ollama/x/model"
    "github.com/ollama/ollama/x/model/input"
)

type embedModel struct {
    model.Base
    model.SentencePiece

    *TextModel
    poolingType pooling.Type

    Dense [2]*nn.Linear `gguf:"dense"`
}

func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
    hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
    for _, dense := range m.Dense {
        hiddenStates = dense.Forward(ctx, hiddenStates)
    }
    hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
    return hiddenStates, nil
}

func newEmbedModel(c fs.Config) (model.Model, error) {
    m := &embedModel{
        SentencePiece: model.NewSentencePiece(
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Scores: c.Floats("tokenizer.ggml.scores"),
                Types:  c.Ints("tokenizer.ggml.token_type"),
                AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
                BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
                AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
                EOS: append(
                    []int32{
                        int32(c.Uint("tokenizer.ggml.eos_token_id")),
                        int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
                    },
                    c.Ints("tokenizer.ggml.eos_token_ids")...,
                ),
            },
        ),
        TextModel:   newTextModel(c),
        poolingType: pooling.Type(c.Uint("pooling_type", 0)),
    }

    return m, nil
}
157
x/model/models/gemma3/model.go
Normal file
@@ -0,0 +1,157 @@
//go:build mlx

package gemma3

import (
    "bytes"
    "image"
    "math"
    "slices"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/x/kvcache"
    "github.com/ollama/ollama/x/ml"
    "github.com/ollama/ollama/x/ml/nn"
    "github.com/ollama/ollama/x/model"
    "github.com/ollama/ollama/x/model/input"
)

type Model struct {
    model.Base
    model.SentencePiece

    *VisionModel `gguf:"vision_tower.vision_model"`
    *TextModel   `gguf:"language_model.model"`

    *MultiModalProjector `gguf:"multi_modal_projector"`

    ImageProcessor
}

var _ model.MultimodalProcessor = (*Model)(nil)

type MultiModalProjector struct {
    SoftEmbNorm     *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
    InputProjection *nn.Linear  `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight

    tokensPerImage int
}

func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
    l := visionOutputs.Dim(0)

    visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
    patchesPerImage := imageSize / patchSize
    visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)

    kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
    visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
    visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
    visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
    visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)

    // TODO: inputProjection must be transposed since they're incompatible with visionOutputs
    visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
    return visionOutputs
}

func New(c fs.Config) (model.Model, error) {
    // slog.Info("XXX Config", "c", c)
    m := Model{
        SentencePiece: model.NewSentencePiece(
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Scores: c.Floats("tokenizer.ggml.scores"),
                Types:  c.Ints("tokenizer.ggml.token_type"),
                AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
                BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
                AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
                EOS: append(
                    []int32{
                        int32(c.Uint("tokenizer.ggml.eos_token_id")),
                        int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
                    },
                    c.Ints("tokenizer.ggml.eos_token_ids")...,
                ),
            },
        ),
        ImageProcessor: newImageProcessor(c),
        VisionModel:    newVisionModel(c),
        TextModel:      newTextModel(c),
        MultiModalProjector: &MultiModalProjector{
            tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
        },
    }

    // slidingWindowLen := int32(c.Uint("attention.sliding_window"))
    // m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))

    // TODO need to implement sliding window...
    m.Cache = kvcache.NewMLXCausalCache()

    return &m, nil
}

func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
    if len(m.VisionModel.Layers) == 0 {
        return nil, model.ErrNoVisionModel
    }

    image, _, err := image.Decode(bytes.NewReader(multimodalData))
    if err != nil {
        return nil, err
    }

    f32s, err := m.ImageProcessor.ProcessImage(image)
    if err != nil {
        return nil, err
    }

    pixelValues := ctx.Input().FromFloats(f32s,
        m.ImageProcessor.imageSize,
        m.ImageProcessor.imageSize,
        m.ImageProcessor.numChannels,
    )

    visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
    visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
    return []input.Multimodal{{Tensor: visionOutputs}}, nil
}

func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
    var result []*input.Input

    for _, inp := range inputs {
        if len(inp.Multimodal) == 0 {
            result = append(result, inp)
        } else {
            inputMultimodal := inp.Multimodal[0].Tensor

            result = append(result,
                &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
                &input.Input{Token: 255999}, // "<start_of_image>"
                &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
            )

            // add image token placeholders
            result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)

            result = append(result,
                &input.Input{Token: 256000}, // <end_of_image>
                &input.Input{Token: 108},    // "\n\n"
            )
        }
    }

    return result, nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
    return m.Output.Forward(ctx, hiddenStates), nil
}

func init() {
    model.Register("gemma3", New)
    model.Register("gemma3_embed", newEmbedModel)
}
211
x/model/models/gemma3/model_text.go
Normal file
@@ -0,0 +1,211 @@
//go:build mlx

package gemma3

import (
    "math"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/x/kvcache"
    "github.com/ollama/ollama/x/ml"
    "github.com/ollama/ollama/x/ml/nn"
    "github.com/ollama/ollama/x/model/input"
)

type TextConfig struct {
    hiddenSize, numHeads, numKVHeads int
    attnKeyLen                       int
    eps, ropeScale                   float32
    ropeLocalBase, ropeGlobalBase    float32
    largeModelScaling                bool
}

type TextModel struct {
    TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
    Layers         []TextLayer   `gguf:"layers"`
    OutputNorm     *nn.RMSNorm   `gguf:"norm"`
    Output         *nn.Linear    `gguf:"embed_tokens"`

    *TextConfig
}

const (
    gemmaGlobalCacheCount = 6
    gemma27BLayerCount    = 62
)

// const (
//     cacheTypeSWA = iota
//     cacheTypeCausal
// )

func newTextModel(c fs.Config) *TextModel {
    numBlocks := int(c.Uint("block_count"))

    m := TextModel{
        Layers: make([]TextLayer, numBlocks),
        TextConfig: &TextConfig{
            hiddenSize:     int(c.Uint("embedding_length")),                 // 2560 -- config.json: text_config.hidden_size
            numHeads:       int(c.Uint("attention.head_count")),             // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
            numKVHeads:     int(c.Uint("attention.head_count_kv")),          // 4 -- same as above
            attnKeyLen:     int(c.Uint("attention.key_length", 256)),        // 256 -- rope settings, hardcoded in model definition python
            eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
            ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),        // 10000 - hardcoded in python
            ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),     // 1e+06 - hardcoded in python
            ropeScale:      1,                                               // 1 - default is 1, implied in python code
            // vocabSize: vocabSize, // 262144
            // attnValLen: int(c.Uint("attention.value_length", 256)), // 256
            // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
            // (8 instead of 1)
            // ropeScale: c.Float("rope.scaling.factor", 1.0),
        },
    }
    if numBlocks == gemma27BLayerCount {
        m.largeModelScaling = true
    }

    return &m
}

type TextSelfAttention struct {
    Query     *nn.Linear  `gguf:"q_proj"`
    QueryNorm *nn.RMSNorm `gguf:"q_norm"`
    Key       *nn.Linear  `gguf:"k_proj"`
    KeyNorm   *nn.RMSNorm `gguf:"k_norm"`
    Value     *nn.Linear  `gguf:"v_proj"`
    Output    *nn.Linear  `gguf:"o_proj"`
}

func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
    B := hiddenState.Dim(0)
    L := hiddenState.Dim(1)
    ropeBase := opts.ropeLocalBase
    if (layer+1)%gemmaGlobalCacheCount == 0 {
        ropeBase = opts.ropeGlobalBase
    }

    q := sa.Query.Forward(ctx, hiddenState)
    k := sa.Key.Forward(ctx, hiddenState)
    v := sa.Value.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
    k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
    v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
    q = sa.QueryNorm.Forward(ctx, q, opts.eps)
    k = sa.KeyNorm.Forward(ctx, k, opts.eps)
    traditional := false
    q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
    k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))

    // TODO - this is wrong somehow so commenting out
    // if opts.largeModelScaling {
    //     q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
    // } else {
    //     q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
    // }

    scaleFactor := math.Pow(256, -0.5)

    kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
    kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
    return sa.Output.Forward(ctx, kqv)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    // ropeBase := m.TextConfig.ropeLocalBase
    // if (layer+1)%gemmaGlobalCacheCount == 0 {
    //     ropeBase = m.TextConfig.ropeGlobalBase
    // }
    // q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
    panic("not yet implemented")
    // return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}

type TextMLP struct {
    Up   *nn.Linear `gguf:"up_proj"`
    Down *nn.Linear `gguf:"down_proj"`
    Gate *nn.Linear `gguf:"gate_proj"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
    return mlp.Down.Forward(ctx, hiddenState)
}

type TextLayer struct {
    AttentionNorm     *nn.RMSNorm        `gguf:"input_layernorm"`
    SelfAttention     *TextSelfAttention `gguf:"self_attn"`
    PostAttentionNorm *nn.RMSNorm        `gguf:"post_attention_layernorm"`
    MLPNorm           *nn.RMSNorm        `gguf:"pre_feedforward_layernorm"`
    MLP               *TextMLP           `gguf:"mlp"`
    PostMLPNorm       *nn.RMSNorm        `gguf:"post_feedforward_layernorm"`
}

func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
    residual := hiddenState
    hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
    hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

    // In the final layer (outputs != nil), optimize by pruning to just the token positions
    // we need logits for.
    if outputs != nil {
        hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
        residual = residual.TakeAxes(ctx, outputs, 1)
    }

    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState
    hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
    hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
    return hiddenState.Add(ctx, residual)
}

func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
    hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
    hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

    // set image embeddings
    // var except []int
    // for _, image := range batch.Multimodal {
    //     visionOutputs := image.Multimodal[0].Tensor
    //     ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
    //         []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
    //         []int{image.Index * hiddenState.Stride(1)}, 0)))

    //     for i := range visionOutputs.Dim(1) {
    //         except = append(except, image.Index+i)
    //     }
    // }

    for i, layer := range m.Layers {
        // gemma alternates between the sliding window (local) and causal (global)
        // kv cache every 6 layers
        if cache != nil {
            // cacheType := cacheTypeSWA
            // if (i+1)%gemmaGlobalCacheCount == 0 {
            //     cacheType = cacheTypeCausal
            // }
            cache.SetLayer(i)

            // TODO this needs to come back
            // wc := cache.(*kvcache.WrapperCache)
            // wc.SetLayerType(cacheType)

            // if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
            //     causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
            // }
        }

        var offset int
        var lastLayerOutputs ml.Tensor
        if i == len(m.Layers)-1 {
            offset = batch.Offset
            lastLayerOutputs = batch.Outputs
        }

        hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
    }
    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
    return hiddenState
}
121
x/model/models/gemma3/model_vision.go
Normal file
@@ -0,0 +1,121 @@
//go:build mlx

package gemma3

import (
    "math"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/x/ml"
    "github.com/ollama/ollama/x/ml/nn"
)

var batchSize int = 1

type VisionSelfAttention struct {
    Query  *nn.Linear `gguf:"self_attn.q_proj"`
    Key    *nn.Linear `gguf:"self_attn.k_proj"`
    Value  *nn.Linear `gguf:"self_attn.v_proj"`
    Output *nn.Linear `gguf:"self_attn.out_proj"`
}

func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    headDim := opts.hiddenSize / opts.numHeads

    query := sa.Query.Forward(ctx, hiddenState)
    key := sa.Key.Forward(ctx, hiddenState)
    value := sa.Value.Forward(ctx, hiddenState)

    query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
    key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
    value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)

    attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
    attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

    hiddenState = sa.Output.Forward(ctx, attention)
    return hiddenState
}

type VisionMLP struct {
    FC1 *nn.Linear `gguf:"fc1"`
    FC2 *nn.Linear `gguf:"fc2"`
}

func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
    hiddenState = mlp.FC2.Forward(ctx, hiddenState)
    return hiddenState
}

type VisionEncoderLayer struct {
    LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
    SelfAttention *VisionSelfAttention

    LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
    MLP        *VisionMLP    `gguf:"mlp"`
}

func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    residual := hiddenState

    // self attention
    hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
    hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState

    // feed forward
    hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
    hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
    return hiddenState.Add(ctx, residual)
}

type VisionModelOptions struct {
    hiddenSize, numHeads int
    imageSize, patchSize int
    eps                  float32
}

type VisionModel struct {
    PatchEmbedding    *nn.Conv2D    `gguf:"embeddings.patch_embedding"`
    PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
    PostLayerNorm     *nn.LayerNorm `gguf:"post_layernorm"`

    Layers []VisionEncoderLayer `gguf:"encoder.layers"`

    *VisionModelOptions
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
    numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)

    hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
    hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
    hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)

    positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
    hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))

    for _, layer := range m.Layers {
        hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
    }

    hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
    return hiddenState
}

func newVisionModel(c fs.Config) *VisionModel {
    return &VisionModel{
        Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
        VisionModelOptions: &VisionModelOptions{
            hiddenSize: int(c.Uint("vision.embedding_length")),
            numHeads:   int(c.Uint("vision.attention.head_count")),

            imageSize: int(c.Uint("vision.image_size")),
            patchSize: int(c.Uint("vision.patch_size")),

            eps: c.Float("vision.attention.layer_norm_epsilon"),
        },
    }
}
60
x/model/models/gemma3/process_image.go
Normal file
@@ -0,0 +1,60 @@
//go:build mlx

package gemma3

import (
    "image"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/model/imageproc"
)

type ImageProcessor struct {
    imageSize, patchSize, numChannels int
}

func newImageProcessor(c fs.Config) ImageProcessor {
    return ImageProcessor{
        imageSize:   int(c.Uint("vision.image_size")),
        patchSize:   int(c.Uint("vision.patch_size")),
        numChannels: int(c.Uint("vision.num_channels")),
    }
}

func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
    var pixelVals, rVals, gVals, bVals []float32

    bounds := img.Bounds()
    for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
        for x := bounds.Min.X; x < bounds.Max.X; x++ {
            c := img.At(x, y)
            r, g, b, _ := c.RGBA()
            rVal := float32(r>>8) / 255.0
            gVal := float32(g>>8) / 255.0
            bVal := float32(b>>8) / 255.0

            rVal = (rVal - mean[0]) / std[0]
            gVal = (gVal - mean[1]) / std[1]
            bVal = (bVal - mean[2]) / std[2]

            rVals = append(rVals, rVal)
            gVals = append(gVals, gVal)
            bVals = append(bVals, bVal)
        }
    }

    pixelVals = append(pixelVals, rVals...)
    pixelVals = append(pixelVals, gVals...)
    pixelVals = append(pixelVals, bVals...)

    return pixelVals
}

func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
    outputSize := image.Point{p.imageSize, p.imageSize}
    newImage := imageproc.Composite(img)
    newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)

    data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
    return data, nil
}
3
x/model/models/models.go
Normal file
@@ -0,0 +1,3 @@
package models

// _ "github.com/ollama/ollama/x/model/models/gemma3"
249
x/model/sentencepiece.go
Normal file
@@ -0,0 +1,249 @@
package model

import (
    "container/heap"
    "fmt"
    "log/slog"
    "strconv"
    "strings"

    "github.com/ollama/ollama/logutil"
)

const spmWhitespaceSep = "▁"

type SentencePiece struct {
    maxTokenLen int
    vocab       *Vocabulary
}

var _ TextProcessor = (*SentencePiece)(nil)

func (spm SentencePiece) Vocabulary() *Vocabulary {
    return spm.vocab
}

func NewSentencePiece(vocab *Vocabulary) SentencePiece {
    logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])

    counter := map[int]int{}
    var maxTokenLen int
    for cnt := range vocab.Types {
        switch vocab.Types[cnt] {
        case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
            maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
            fallthrough
        default:
            counter[int(vocab.Types[cnt])] += 1
        }
    }

    logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
        "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
        "max token len", maxTokenLen)

    return SentencePiece{
        maxTokenLen: maxTokenLen,
        vocab:       vocab,
    }
}

func (spm SentencePiece) Is(id int32, special Special) bool {
    return spm.vocab.Is(id, special)
}

func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
    fragments := []fragment{{value: s}}
    for _, special := range spm.vocab.SpecialVocabulary() {
        id := spm.vocab.Encode(special)
        for i := 0; i < len(fragments); i++ {
            frag := fragments[i]
            if len(frag.ids) > 0 {
                continue
            }

            var middle []fragment
            switch i := strings.Index(frag.value, special); {
            case i < 0:
                middle = append(middle, frag)
            case i > 0:
                middle = append(middle, fragment{value: frag.value[:i]})
                fallthrough
            default:
                middle = append(middle, fragment{value: special, ids: []int32{id}})
                if rest := frag.value[i+len(special):]; rest != "" {
                    middle = append(middle, fragment{value: rest})
                }
            }

            fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
        }
    }

    var ids []int32
    for _, frag := range fragments {
        if len(frag.ids) > 0 {
            ids = append(ids, frag.ids...)
            continue
        }

        text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)

        if id := spm.vocab.Encode(text); id >= 0 {
            ids = append(ids, id)
            continue
        }

        q := &queue{}
        heap.Init(q)

        runes := []rune(text)
        merges := make([]merge, len(runes))
        for r := range runes {
            merges[r] = merge{
                p:     r - 1,
                n:     r + 1,
                runes: []rune{runes[r]},
            }
        }

        pairwise := func(a, b int) *candidate {
            if a < 0 || b >= len(runes) {
                return nil
            }

            left, right := string(merges[a].runes), string(merges[b].runes)
            if id := spm.vocab.Encode(left + right); id >= 0 {
                return &candidate{
                    a:     a,
                    b:     b,
                    score: spm.vocab.Scores[id],
                    size:  len(left) + len(right),
                }
            }

            return nil
        }

        for i := range len(runes) - 1 {
            if pair := pairwise(i, i+1); pair != nil {
                heap.Push(q, pair)
            }
        }

        for q.Len() > 0 {
            pair := heap.Pop(q).(*candidate)
            left, right := merges[pair.a], merges[pair.b]

            if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
                continue
            }

            merges[pair.a].runes = append(left.runes, right.runes...)
            merges[pair.b].runes = nil
            merges[pair.a].n = right.n
            if right.n < len(merges) {
                merges[right.n].p = pair.a
            }

            if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
                heap.Push(q, pair)
            }

            if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
                heap.Push(q, pair)
            }
        }

        for _, merge := range merges {
            if token := string(merge.runes); token != "" {
                id := spm.vocab.Encode(token)

                if id >= 0 {
                    ids = append(ids, id)
                    continue
                }

                // Fallback to byte tokenization
                var result []int32
                for _, b := range []byte(token) {
                    byteToken := fmt.Sprintf("<0x%02X>", b)
                    unknownID := spm.vocab.Encode(byteToken)
                    if unknownID >= 0 {
                        result = append(result, unknownID)
                    } else {
                        slog.Debug("unknown byte token", "byte", b, "token", byteToken)
                    }
                }

                ids = append(ids, result...)
            }
        }
    }

    if addSpecial {
        ids = spm.vocab.addSpecials(ids)
    }

    logutil.Trace("encoded", "string", s, "ids", ids)
    return ids, nil
}

type candidate struct {
    a, b  int
    score float32
    size  int
}

type queue []*candidate

func (q queue) Len() int { return len(q) }

func (q queue) Less(i, j int) bool {
    return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}

func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }

func (q *queue) Push(x interface{}) {
    item := x.(*candidate)
    *q = append(*q, item)
}

func (q *queue) Pop() interface{} {
    old := *q
    n := len(old)
    item := old[n-1]
    *q = old[0 : n-1]
    return item
}

func (spm SentencePiece) Decode(ids []int32) (string, error) {
    var sb strings.Builder
    for _, id := range ids {
        data := spm.vocab.Decode(id)
        data = strings.ReplaceAll(data, spmWhitespaceSep, " ")

        // For tokenizers that use byte tokens like "<0xEA>"
        // convert them to the partial unicode character
        // so they are buffered correctly by the runner instead
        // of being sent back to the api as "<0xEA>"
        if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
            byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
            if err != nil {
                return "", fmt.Errorf("failed to parse hex byte: %v", err)
            }

            if err := sb.WriteByte(byte(byteVal)); err != nil {
                return "", err
            }
        } else {
            if _, err := sb.WriteString(data); err != nil {
                return "", err
            }
        }
    }

    logutil.Trace("decoded", "ids", ids, "string", sb.String())
    return sb.String(), nil
}
172
x/model/sentencepiece_test.go
Normal file
@@ -0,0 +1,172 @@
package model

import (
    "log/slog"
    "os"
    "path/filepath"
    "slices"
    "testing"

    "google.golang.org/protobuf/proto"

    "github.com/ollama/ollama/convert/sentencepiece"
)

func loadSentencePieceVocab(t *testing.T) SentencePiece {
    t.Helper()

    bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
    if err != nil {
        t.Fatal(err)
    }

    var spm sentencepiece.ModelProto
    if err := proto.Unmarshal(bts, &spm); err != nil {
        t.Fatal(err)
    }

    var v Vocabulary

    for _, piece := range spm.GetPieces() {
        v.Values = append(v.Values, piece.GetPiece())
        v.Scores = append(v.Scores, piece.GetScore())
        switch t := piece.GetType(); t {
        case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
            sentencepiece.ModelProto_SentencePiece_CONTROL,
            sentencepiece.ModelProto_SentencePiece_UNUSED,
            sentencepiece.ModelProto_SentencePiece_BYTE:
            v.Types = append(v.Types, int32(t))
        default:
            tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
            // todo parse the special tokens file
            // - this will roundtrip correctly but the <start_of_turn> and
            //   <end_of_turn> tokens aren't processed
            v.Types = append(v.Types, tt)
        }
    }

    return NewSentencePiece(&v)
}

func TestSentencePieceEncode(t *testing.T) {
    logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
    slog.SetDefault(logger)

    tokenizer := loadSentencePieceVocab(t)

    t.Run("basic roundtrip", func(t *testing.T) {
        t.Parallel()

        cases := []string{
            "hello",
            "hello ",
            "hello  ",
            " hello",
            " hello ",
            " hello  ",
            "hello world",
            "请考试我的软件!12345",
            "你好",
            "Hello 你好 world!",
            "Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
            "Multilingual: 你好 こんにちは Привет Hola مرحبا",
            "Numbers and symbols: 123456789 +- */",
            "Special tokens: <bos> text <eos>",
            "Code snippets: func main() { fmt.Println(\"Hello World\") }",
            "Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
                "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
        }

        for _, want := range cases {
            ids, err := tokenizer.Encode(want, true)
            if err != nil {
                t.Fatal(err)
            }

            if got, err := tokenizer.Decode(ids); err != nil {
                t.Fatal(err)
            } else if got != want {
                t.Errorf("got %q, want %q [%#v]", got, want, ids)
            }
        }
    })

    t.Run("special tokens", func(t *testing.T) {
        type candidate struct {
            token string
            ids   []int32
        }

        cases := []candidate{
            {"<bos>", []int32{2}},
            {"<eos>", []int32{1}},
        }

        for _, want := range cases {
            ids, err := tokenizer.Encode(want.token, true)
            if err != nil {
                t.Fatal(err)
            }
            if !slices.Equal(ids, want.ids) {
                t.Errorf("got %#v, want %#v", ids, want.ids)
            }
        }
    })
}

func TestSentencePieceDecodeByteTokens(t *testing.T) {
    vocab := &Vocabulary{
        Values: []string{
            "normal",
            "<0xEA>",
            "<0x41>",
            "<0xC3>",
            "<0xA3>",
        },
        Types: []int32{
            TOKEN_TYPE_NORMAL,
            TOKEN_TYPE_BYTE,
            TOKEN_TYPE_BYTE,
            TOKEN_TYPE_BYTE,
            TOKEN_TYPE_BYTE,
        },
        Scores: []float32{0, 0, 0, 0, 0},
    }

    spm := NewSentencePiece(vocab)

    tests := []struct {
        name     string
        ids      []int32
        expected string
    }{
        {
            name:     "single byte token",
            ids:      []int32{1},
            expected: "\xea",
        },
        {
            name:     "ASCII byte token",
            ids:      []int32{2},
            expected: "A",
        },
        {
            name:     "multiple byte tokens forming UTF-8 character",
            ids:      []int32{3, 4},
            expected: "ã",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result, err := spm.Decode(tt.ids)
            if err != nil {
                t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
            }
            if result != tt.expected {
                t.Errorf("got %q, want %q", result, tt.expected)
            }
        })
    }
}
17
x/model/textprocessor.go
Normal file
@@ -0,0 +1,17 @@
package model

const (
    TOKEN_TYPE_NORMAL = iota + 1
    TOKEN_TYPE_UNKNOWN
    TOKEN_TYPE_CONTROL
    TOKEN_TYPE_USER_DEFINED
    TOKEN_TYPE_UNUSED
    TOKEN_TYPE_BYTE
)

type TextProcessor interface {
    Encode(s string, addSpecial bool) ([]int32, error)
    Decode([]int32) (string, error)
    Is(int32, Special) bool
    Vocabulary() *Vocabulary
}
112
x/model/vocabulary.go
Normal file
@@ -0,0 +1,112 @@
package model

import (
    "log/slog"
    "slices"
    "sync"
)

type Special int32

const (
    SpecialBOS Special = iota
    SpecialEOS
)

type Vocabulary struct {
    Values []string
    Types  []int32
    Scores []float32
    Merges []string

    BOS, EOS       []int32
    AddBOS, AddEOS bool

    specialOnce sync.Once
    special     []string

    valuesOnce sync.Once
    values     map[string]int32

    mergeOnce sync.Once
    merge     map[string]int32
}

func (v *Vocabulary) Is(id int32, special Special) bool {
    switch special {
    case SpecialBOS:
        return slices.Contains(v.BOS, id)
    case SpecialEOS:
        return slices.Contains(v.EOS, id)
    default:
        return false
    }
}

func (v *Vocabulary) addSpecials(ids []int32) []int32 {
    if v.AddBOS && len(v.BOS) > 0 {
        if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
            slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
        }

        slog.Debug("adding bos token to prompt", "id", v.BOS[0])
        ids = append([]int32{v.BOS[0]}, ids...)
    }

    if v.AddEOS && len(v.EOS) > 0 {
        if len(ids) > 0 && slices.Contains(v.EOS, ids[len(ids)-1]) {
            slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
        }

        slog.Debug("adding eos token to prompt", "id", v.EOS[0])
        ids = append(ids, v.EOS[0])
    }

    return ids
}

func (v *Vocabulary) Encode(s string) int32 {
    v.valuesOnce.Do(func() {
        v.values = make(map[string]int32, len(v.Values))
        for i, value := range v.Values {
            v.values[value] = int32(i)
        }
    })

    if id, ok := v.values[s]; ok {
        return id
    }

    return -1
}

func (v *Vocabulary) Decode(id int32) string {
    return v.Values[id]
}

func (v *Vocabulary) SpecialVocabulary() []string {
    v.specialOnce.Do(func() {
        for i := range v.Values {
            if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
                v.special = append(v.special, v.Values[i])
            }
        }
    })

    return v.special
}

func (v *Vocabulary) Merge(left, right string) int {
    v.mergeOnce.Do(func() {
        v.merge = make(map[string]int32, len(v.Merges))
        for i, merge := range v.Merges {
            v.merge[merge] = int32(i)
        }
    })

    if id, ok := v.merge[left+" "+right]; ok {
        return int(id)
    }

    return -1
}
107
x/model/vocabulary_test.go
Normal file
@@ -0,0 +1,107 @@
package model

import (
    "testing"

    "github.com/google/go-cmp/cmp"
)

func TestSpecialVocabulary(t *testing.T) {
    vocab := &Vocabulary{
        Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
        Types:  []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
    }

    specialVocab := vocab.SpecialVocabulary()

    if len(specialVocab) != 4 {
        t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
    }
}

func TestAddSpecialVocabulary(t *testing.T) {
    cases := []struct {
        name  string
        vocab *Vocabulary
        input []int32
        want  []int32
    }{
        {
            name: "add bos",
            vocab: &Vocabulary{
                BOS:    []int32{0},
                EOS:    []int32{1},
                AddBOS: true,
                AddEOS: false,
            },
            input: []int32{2, 3, 4},
            want:  []int32{0, 2, 3, 4},
        },
        {
            // TODO(mxyng): this is to match previous behaviour
            name: "add bos when already present",
            vocab: &Vocabulary{
                BOS:    []int32{0},
                EOS:    []int32{1},
                AddBOS: true,
                AddEOS: false,
            },
            input: []int32{0, 2, 3, 4},
            want:  []int32{0, 0, 2, 3, 4},
        },
        {
            name: "add eos",
            vocab: &Vocabulary{
                BOS:    []int32{0},
                EOS:    []int32{1},
                AddBOS: false,
                AddEOS: true,
            },
            input: []int32{2, 3, 4},
            want:  []int32{2, 3, 4, 1},
        },
        {
            // TODO(mxyng): this is to match previous behaviour
            name: "add eos when already present",
            vocab: &Vocabulary{
                BOS:    []int32{0},
                EOS:    []int32{1},
                AddBOS: false,
                AddEOS: true,
            },
            input: []int32{2, 3, 4, 1},
            want:  []int32{2, 3, 4, 1, 1},
        },
        {
            name: "add both",
            vocab: &Vocabulary{
                BOS:    []int32{0},
                EOS:    []int32{1},
                AddBOS: true,
                AddEOS: true,
            },
            input: []int32{2, 3, 4},
            want:  []int32{0, 2, 3, 4, 1},
        },
        {
            name: "add bos to empty inputs",
            vocab: &Vocabulary{
                BOS:    []int32{0},
                EOS:    []int32{1},
                AddBOS: true,
                AddEOS: false,
            },
            input: []int32{},
            want:  []int32{0},
        },
    }

    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            got := tt.vocab.addSpecials(tt.input)
            if diff := cmp.Diff(tt.want, got); diff != "" {
                t.Errorf("no match (-want +got):\n%s", diff)
            }
        })
    }
}
171
x/model/wordpiece.go
Normal file
@@ -0,0 +1,171 @@
package model

import (
    "fmt"
    "iter"
    "strings"
    "unicode"

    "github.com/ollama/ollama/logutil"
)

type WordPiece struct {
    vocab     *Vocabulary
    lowercase bool
}

// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"

var wordPieceReplacer = strings.NewReplacer(
    " .", ".",
    " ?", "?",
    " !", "!",
    " ,", ",",
    " ' ", "'",
    " n't", "n't",
    " 'm", "'m",
    " do not", " don't",
    " 's", "'s",
    " 've", "'ve",
    " 're", "'re",
)

// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
    var sb strings.Builder
    for i, id := range ids {
        if id < 0 || int(id) >= len(wpm.vocab.Values) {
            return "", fmt.Errorf("invalid token id: %d", id)
        }

        var separator string
        piece := wpm.vocab.Values[id]
        if i > 0 &&
            (strings.HasPrefix(piece, ggmlPrefix) ||
                (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
            separator = " "
        }

        sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
    }

    return sb.String(), nil
}

// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
    return func(yield func(string) bool) {
        runes := make([]rune, 0, len(s)*3)
        for _, r := range s {
            switch {
            case r >= 0x4E00 && r <= 0x9FFF,
                r >= 0x3400 && r <= 0x4DBF,
                r >= 0x20000 && r <= 0x2A6DF,
                r >= 0x2A700 && r <= 0x2B73F,
                r >= 0x2B740 && r <= 0x2B81F,
                r >= 0x2B820 && r <= 0x2CEAF,
                r >= 0xF900 && r <= 0xFAFF,
                r >= 0x2F800 && r <= 0x2FA1F:
                runes = append(runes, ' ', r, ' ')
            default:
                runes = append(runes, r)
            }
        }

        for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
            // split on but keep punctuation
            var start int
            for start < len(w) {
                end := strings.IndexFunc(w[start:], unicode.IsPunct)
                if end < 0 {
                    end = len(w) - start
                } else if end == 0 {
                    end = 1
                }

                if !yield(w[start : start+end]) {
                    return
                }

                start += end
            }
        }
    }
}

// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
    var ids []int32

    // TODO: use [UNK] from config
    unk := wpm.vocab.Encode("[UNK]")
    for word := range wpm.words(s) {
        var start int
        var pieces []int32
        for start < len(word) {
            end := len(word)

            var piece int32
            for start < end {
                subword := word[start:end]
                if start == 0 {
                    subword = ggmlPrefix + subword
                }

                if wpm.lowercase {
                    subword = strings.ToLower(subword)
                }
                piece = wpm.vocab.Encode(subword)
                if piece >= 0 {
                    break
                }

                end--
            }

            if piece < 0 {
                // Unknown token
                pieces = pieces[:0]
                break
            }

            pieces = append(pieces, piece)
            start = end
        }

        if len(pieces) > 0 {
            ids = append(ids, pieces...)
        } else {
            ids = append(ids, unk)
        }
    }

    if addSpecial {
        ids = wpm.vocab.addSpecials(ids)
    }

    logutil.Trace("encoded", "string", s, "ids", ids)
    return ids, nil
}

// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
    return wpm.vocab.Is(id, special)
}

// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
    return wpm.vocab
}

var _ TextProcessor = (*WordPiece)(nil)

func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
    return WordPiece{
        vocab:     vocab,
        lowercase: lowercase,
    }
}
53
x/model/wordpiece_test.go
Normal file
@@ -0,0 +1,53 @@
package model

import (
    "slices"
    "testing"

    "github.com/google/go-cmp/cmp"
)

func TestWordPiece(t *testing.T) {
    wpm := NewWordPiece(
        &Vocabulary{
            Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
            AddBOS: true,
            AddEOS: true,
            BOS:    []int32{1},
            EOS:    []int32{2},
        },
        true, // lowercase
    )

    ids, err := wpm.Encode("Hello world!", true)
    if err != nil {
        t.Fatal(err)
    }

    if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
        t.Errorf("unexpected ids (-want +got):\n%s", diff)
    }

    words, err := wpm.Decode(ids)
    if err != nil {
        t.Fatal(err)
    }

    if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
        t.Errorf("unexpected words (-want +got):\n%s", diff)
    }
}

func TestWordPieceWords(t *testing.T) {
    var wpm WordPiece

    basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
    if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
        t.Errorf("unexpected words (-want +got):\n%s", diff)
    }

    chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
    if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
        t.Errorf("unexpected words (-want +got):\n%s", diff)
    }
}