Compare commits
32 Commits
mlx-gpu-cd ... jmorganca/

| SHA1 |
|---|
| b220bfa0b9 |
| c23d5095de |
| 7601f0e93e |
| aad3f03890 |
| 55d0b6e8b9 |
| 38eac40d56 |
| 80f3f1bc25 |
| b1a0db547b |
| 75d7b5f926 |
| 349d814814 |
| c8743031e0 |
| 4adb9cf4bb |
| 74f475e735 |
| 875cecba74 |
| 7d411a4686 |
| 02a2401596 |
| e4b488a7b5 |
| 98079ddd79 |
| d70942f47b |
| 58e4701557 |
| dbf47ee55a |
| af7ea6e96e |
| 8f1e0140e7 |
| 35c3c9e3c2 |
| d06acbcb19 |
| 9667c2282f |
| a937a68317 |
| 2185112d84 |
| 91926601dc |
| 361d6c16c2 |
| 7e2496e88e |
| 5b84e29882 |
2  .github/ISSUE_TEMPLATE/10_bug_report.yml  vendored
@@ -13,7 +13,7 @@ body:
     id: logs
     attributes:
       label: Relevant log output
-      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
+      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
       render: shell
     validations:
       required: false
6  .github/workflows/release.yaml  vendored
@@ -372,13 +372,17 @@ jobs:
           outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
           cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
           cache-to: type=inline
       - name: Deduplicate CUDA libraries
         run: |
           ./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
       - run: |
           for COMPONENT in bin/* lib/ollama/*; do
             case "$COMPONENT" in
-            bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+            bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
             lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
             lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
             lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
CMakeLists.txt
@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
   set(GGML_CPU_ALL_VARIANTS ON)
 endif()

-if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
+if(APPLE)
   set(CMAKE_BUILD_RPATH "@loader_path")
   set(CMAKE_INSTALL_RPATH "@loader_path")
+  set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
 endif()

 set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)

@@ -189,13 +190,21 @@ if(MLX_ENGINE)
   install(TARGETS mlx mlxc
     RUNTIME_DEPENDENCIES
       DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-      PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
+      PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
       PRE_EXCLUDE_REGEXES ".*"
     RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
     LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
     FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
   )

+  # Install the Metal library for macOS arm64 (must be colocated with the binary)
+  # Metal backend is only built for arm64, not x86_64
+  if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
+      DESTINATION ${OLLAMA_INSTALL_DIR}
+      COMPONENT MLX)
+  endif()
+
   # Manually install cudart and cublas since they might not be picked up as direct dependencies
   if(CUDAToolkit_FOUND)
     file(GLOB CUDART_LIBS
Dockerfile
@@ -161,6 +161,9 @@ ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1
 ARG CGO_CFLAGS
 ARG CGO_CXXFLAGS
+RUN mkdir -p dist/bin
+RUN --mount=type=cache,target=/root/.cache/go-build \
+    go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .

 FROM base AS build
 WORKDIR /go/src/github.com/ollama/ollama

@@ -182,6 +185,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
 COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
 COPY --from=vulkan dist/lib/ollama /lib/ollama/
 COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
+COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/

 FROM --platform=linux/arm64 scratch AS arm64
 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
43  README.md
@@ -48,7 +48,7 @@ ollama run gemma3

 ## Model library

-Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
+Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")

 Here are some example models that can be downloaded:

@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
 | Code Llama | 7B | 3.8GB | `ollama run codellama` |
 | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
-| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
+| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |

 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
 ./ollama run llama3.2
 ```

+## Building with MLX (experimental)
+
+First build the MLX libraries:
+
+```shell
+cmake --preset MLX
+cmake --build --preset MLX --parallel
+cmake --install build --component MLX
+```
+
+Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
+
+```shell
+go build -tags mlx -o ollama-mlx .
+```
+
+Finally, start the server:
+
+```
+./ollama serve
+```
+
+### Building MLX with CUDA
+
+When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
+
+```shell
+cmake --preset 'MLX CUDA 13'
+cmake --build --preset 'MLX CUDA 13' --parallel
+cmake --install build --component MLX
+```
+
 ## REST API

 Ollama has a REST API for running and managing models.

@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.

 ### Web & Desktop

+- [Onyx](https://github.com/onyx-dot-app/onyx)
 - [Open WebUI](https://github.com/open-webui/open-webui)
 - [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)

@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
 - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
 - [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
-- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
+- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
 - [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
 - [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
 - [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)

@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Database

 - [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
-  - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
+  - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
 - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
 - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
 - [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)

@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.

 ### Observability

 - [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.

@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.

 ### Security

 - [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
api/client.go
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }

-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 8 * format.MegaByte

 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf io.Reader
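For context on why this constant matters, here is a minimal sketch of how such a cap is typically applied to a `bufio.Scanner` (assumed usage; the body of `stream` is not shown in this diff). A scanner rejects any token larger than its maximum buffer, so raising the cap from 512 KB to 8 MB lets the client accept much larger streamed JSON lines:

```go
package api

import (
	"bufio"
	"io"
)

// scanStream sketches the assumed consumer of maxBufferSize: a line scanner
// over a streaming response body. bufio.Scanner fails on any token larger
// than its maximum buffer, so the cap bounds the largest line accepted.
func scanStream(r io.Reader, fn func([]byte) error) error {
	scanner := bufio.NewScanner(r)
	// Start small; allow the buffer to grow up to maxBufferSize.
	scanner.Buffer(make([]byte, 0, 64*1024), maxBufferSize)
	for scanner.Scan() {
		if err := fn(scanner.Bytes()); err != nil {
			return err
		}
	}
	return scanner.Err()
}
```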
@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
 @interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
 @property(strong, nonatomic) NSStatusItem *statusItem;
 @property(assign, nonatomic) BOOL updateAvailable;
+@property(assign, nonatomic) BOOL systemShutdownInProgress;
 @end

 @implementation AppDelegate

@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
 }

 - (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
+  // Register for system shutdown/restart notification so we can allow termination
+  [[[NSWorkspace sharedWorkspace] notificationCenter]
+      addObserver:self
+         selector:@selector(systemWillPowerOff:)
+             name:NSWorkspaceWillPowerOffNotification
+           object:nil];
+
   // if we're in development mode, set the app icon
   NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
   if (![bundlePath hasSuffix:@".app"]) {

@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
   [NSApp activateIgnoringOtherApps:YES];
 }

+- (void)systemWillPowerOff:(NSNotification *)notification {
+  // Set flag so applicationShouldTerminate: knows to allow termination.
+  // The system will call applicationShouldTerminate: after posting this notification.
+  self.systemShutdownInProgress = YES;
+}
+
 - (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
+  // Allow termination if the system is shutting down or restarting
+  if (self.systemShutdownInProgress) {
+    return NSTerminateNow;
+  }
+  // Otherwise just hide the app (for Cmd+Q, close button, etc.)
   [NSApp hide:nil];
   [NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
   return NSTerminateCancel;
34  cmd/cmd.go
@@ -100,7 +100,8 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	if filename == "" {
 		// No Modelfile found - check if current directory is an image gen model
 		if imagegen.IsTensorModelDir(".") {
-			return imagegenclient.CreateModel(args[0], ".", p)
+			quantize, _ := cmd.Flags().GetString("quantize")
+			return imagegenclient.CreateModel(args[0], ".", quantize, p)
 		}
 		reader = strings.NewReader("FROM .\n")
 	} else {

@@ -464,14 +465,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {

 	name := args[0]

-	// Check if this is a known image generation model (skip Show/Pull)
-	if imagegen.HasTensorLayers(name) {
-		if opts.Prompt == "" && !interactive {
-			return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
-		}
-		return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
-	}
-
 	info, err := func() (*api.ShowResponse, error) {
 		showReq := &api.ShowRequest{Name: name}
 		info, err := client.Show(cmd.Context(), showReq)

@@ -533,9 +526,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
 	}

+	// Check if this is an image generation model
+	if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
+		if opts.Prompt == "" && !interactive {
+			return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
+		}
+		return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
+	}
+
 	// Check for experimental flag
 	isExperimental, _ := cmd.Flags().GetBool("experimental")
 	yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
+	enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")

 	if interactive {
 		if err := loadOrUnloadModel(cmd, &opts); err != nil {

@@ -565,7 +567,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {

 	// Use experimental agent loop with tools
 	if isExperimental {
-		return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
+		return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
 	}

 	return generateInteractive(cmd, opts)

@@ -671,7 +673,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {

 		bar, ok := bars[resp.Digest]
 		if !ok {
-			bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+			msg := resp.Status
+			if msg == "" {
+				msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
+			}
+			bar = progress.NewBar(msg, resp.Total, resp.Completed)
 			bars[resp.Digest] = bar
 			p.Add(resp.Digest, bar)
 		}

@@ -837,11 +843,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 }

 func ShowHandler(cmd *cobra.Command, args []string) error {
-	// Check if this is an image generation model
-	if imagegen.HasTensorLayers(args[0]) {
-		return imagegen.Show(args[0], os.Stdout)
-	}
-
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err

@@ -1786,6 +1787,7 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
 	runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
 	runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
+	runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")

 	// Image generation flags (width, height, steps, seed, etc.)
 	imagegen.RegisterFlags(runCmd)
@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
 	}
 }

+func TestShowInfoImageGen(t *testing.T) {
+	var b bytes.Buffer
+	err := showInfo(&api.ShowResponse{
+		Details: api.ModelDetails{
+			Family:            "ZImagePipeline",
+			ParameterSize:     "10.3B",
+			QuantizationLevel: "FP8",
+		},
+		Capabilities: []model.Capability{model.CapabilityImageGeneration},
+		Requires:     "0.14.0",
+	}, false, &b)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expect := " Model\n" +
+		" architecture ZImagePipeline \n" +
+		" parameters 10.3B \n" +
+		" quantization FP8 \n" +
+		" requires 0.14.0 \n" +
+		"\n" +
+		" Capabilities\n" +
+		" image \n" +
+		"\n"
+	if diff := cmp.Diff(expect, b.String()); diff != "" {
+		t.Errorf("unexpected output (-want +got):\n%s", diff)
+	}
+}
+
+func TestPushProgressMessage(t *testing.T) {
+	tests := []struct {
+		name    string
+		status  string
+		digest  string
+		wantMsg string
+	}{
+		{
+			name:    "uses status when provided",
+			status:  "uploading model",
+			digest:  "sha256:abc123456789def",
+			wantMsg: "uploading model",
+		},
+		{
+			name:    "falls back to digest when status empty",
+			status:  "",
+			digest:  "sha256:abc123456789def",
+			wantMsg: "pushing abc123456789...",
+		},
+		{
+			name:    "handles short digest gracefully",
+			status:  "",
+			digest:  "sha256:abc",
+			wantMsg: "pushing sha256:abc...",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			msg := tt.status
+			if msg == "" {
+				if len(tt.digest) >= 19 {
+					msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
+				} else {
+					msg = fmt.Sprintf("pushing %s...", tt.digest)
+				}
+			}
+			if msg != tt.wantMsg {
+				t.Errorf("got %q, want %q", msg, tt.wantMsg)
+			}
+		})
+	}
+}
+
 func TestRunOptions_Copy_Independence(t *testing.T) {
 	// Test that modifications to original don't affect copy
 	originalThink := &api.ThinkValue{Value: "original"}
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		Prompt:         ">>> ",
 		AltPrompt:      "... ",
 		Placeholder:    "Send a message (/? for help)",
-		AltPlaceholder: `Use """ to end multi-line input`,
+		AltPlaceholder: "Press Enter to send",
 	})
 	if err != nil {
 		return err
@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
 To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:

 ```shell
+export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama # required but ignored
 ```

@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
 [Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:

 ```shell
-ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
 ```

 Or set the environment variables in your shell profile:

 ```shell
+export ANTHROPIC_AUTH_TOKEN=ollama
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama
 ```
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
 import { Ollama } from "ollama";

 const client = new Ollama();
-const results = await client.webSearch({ query: "what is ollama?" });
+const results = await client.webSearch("what is ollama?");
 console.log(JSON.stringify(results, null, 2));
 ```

@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
 import { Ollama } from "ollama";

 const client = new Ollama();
-const fetchResult = await client.webFetch({ url: "https://ollama.com" });
+const fetchResult = await client.webFetch("https://ollama.com");
 console.log(JSON.stringify(fetchResult, null, 2));
 ```
@@ -111,7 +111,9 @@
         "/integrations/zed",
         "/integrations/roo-code",
         "/integrations/n8n",
-        "/integrations/xcode"
+        "/integrations/xcode",
+        "/integrations/onyx",
+        "/integrations/marimo"
       ]
     },
     {
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).

 ## How can I specify the context window size?

-By default, Ollama uses a context window size of 2048 tokens.
+By default, Ollama uses a context window size of 4096 tokens.

 This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
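The hunk cuts off just before the example command; an illustrative invocation (not part of the diff, based on the variable named above) would look like:

```shell
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```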
BIN  docs/images/marimo-add-model.png  (new file, 174 KiB)
BIN  docs/images/marimo-chat.png  (new file, 80 KiB)
BIN  docs/images/marimo-code-completion.png  (new file, 230 KiB)
BIN  docs/images/marimo-models.png  (new file, 178 KiB)
BIN  docs/images/marimo-settings.png  (new file, 186 KiB)
BIN  docs/images/onyx-login.png  (new file, 100 KiB)
BIN  docs/images/onyx-ollama-form.png  (new file, 306 KiB)
BIN  docs/images/onyx-ollama-llm.png  (new file, 300 KiB)
BIN  docs/images/onyx-query.png  (new file, 211 KiB)
@@ -25,6 +25,7 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
 1. Set the environment variables:

    ```shell
+   export ANTHROPIC_AUTH_TOKEN=ollama
    export ANTHROPIC_BASE_URL=http://localhost:11434
    export ANTHROPIC_API_KEY=ollama
    ```

@@ -38,7 +39,7 @@ claude --model qwen3-coder
 Or run with environment variables inline:

 ```shell
-ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
 ```

 ## Connecting to ollama.com
73  docs/integrations/marimo.mdx  (new file)
@@ -0,0 +1,73 @@
+---
+title: marimo
+---
+
+## Install
+
+Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
+can also use `uv` to create a sandboxed environment for marimo by running:
+
+```
+uvx marimo edit --sandbox notebook.py
+```
+
+## Usage with Ollama
+
+1. In marimo, go to the user settings and open the AI tab. From here
+   you can find and configure Ollama as an AI provider. For local use you
+   would typically point the base URL to `http://localhost:11434/v1`.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-settings.png"
+    alt="Ollama settings in marimo"
+    width="50%"
+  />
+</div>
+
+2. Once the AI provider is set up, you can turn specific AI models you'd like to access on or off.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-models.png"
+    alt="Selecting an Ollama model"
+    width="50%"
+  />
+</div>
+
+3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-add-model.png"
+    alt="Adding a new Ollama model"
+    width="50%"
+  />
+</div>
+
+4. Once configured, you can now use Ollama for AI chats in marimo.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-chat.png"
+    alt="AI chat in marimo"
+    width="50%"
+  />
+</div>
+
+5. Alternatively, you can use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/marimo-code-completion.png"
+    alt="Configure code completion"
+    width="50%"
+  />
+</div>
+
+## Connecting to ollama.com
+
+1. Sign in to Ollama Cloud via `ollama signin`.
+2. In the Ollama model settings, add a model that Ollama hosts, like `gpt-oss:120b`.
+3. You can now refer to this model in marimo!
63  docs/integrations/onyx.mdx  (new file)
@@ -0,0 +1,63 @@
+---
+title: Onyx
+---
+
+## Overview
+
+[Onyx](http://onyx.app/) is a self-hostable chat UI that integrates with all Ollama models. Features include:
+
+- Creating custom Agents
+- Web search
+- Deep Research
+- RAG over uploaded documents and connected apps
+- Connectors to applications like Google Drive, Email, Slack, etc.
+- MCP and OpenAPI Actions support
+- Image generation
+- User/group management, RBAC, SSO, etc.
+
+Onyx can be deployed for single users or large organizations.
+
+## Install Onyx
+
+Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
+
+<Info>
+  Resourcing and scaling docs are [here](https://docs.onyx.app/deployment/getting_started/resourcing).
+</Info>
+
+## Usage with Ollama
+
+1. Log in to your Onyx deployment (create an account first).
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-login.png"
+    alt="Onyx Login Page"
+    width="75%"
+  />
+</div>
+2. In the set-up process, select `Ollama` as the LLM provider.
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-ollama-llm.png"
+    alt="Onyx Set Up Form"
+    width="75%"
+  />
+</div>
+3. Provide your **Ollama API URL** and select your models.
+<Note>If you're running Onyx in Docker, use `http://host.docker.internal` instead of `http://127.0.0.1` to reach your computer's local network.</Note>
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-ollama-form.png"
+    alt="Selecting Ollama Models"
+    width="75%"
+  />
+</div>
+
+You can also connect Onyx Cloud through the `Ollama Cloud` tab of the setup.
+
+## Send your first query
+
+<div style={{ display: 'flex', justifyContent: 'center' }}>
+  <img
+    src="/images/onyx-query.png"
+    alt="Onyx Query Example"
+    width="75%"
+  />
+</div>
@@ -1,5 +1,5 @@
 ---
-title: "Linux"
+title: Linux
 ---

 ## Install

@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
 ## Manual install

 <Note>
-If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
+If you are upgrading from a prior version, you should remove the old libraries
+with `sudo rm -rf /usr/lib/ollama` first.
 </Note>

 Download and extract the package:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
+    | sudo tar x -C /usr
 ```

 Start Ollama:

@@ -40,8 +41,8 @@ ollama -v
 If you have an AMD GPU, also download and extract the additional ROCm package:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
+    | sudo tar x -C /usr
 ```

 ### ARM64 install

@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
 Download and extract the ARM64-specific package:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
+    | sudo tar x -C /usr
 ```

 ### Adding Ollama as a startup service (recommended)

@@ -112,7 +113,11 @@ sudo systemctl status ollama

 <Note>
-While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
+While AMD has contributed the `amdgpu` driver upstream to the official linux
+kernel source, the version is older and may not support all ROCm features. We
+recommend you install the latest driver from
+https://www.amd.com/en/support/linux-drivers for best support of your Radeon
+GPU.
 </Note>

 ## Customizing

@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
 Or by re-downloading Ollama:

 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
-    | sudo tar zx -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
+    | sudo tar x -C /usr
 ```

 ## Installing specific versions

@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
 sudo userdel ollama
 sudo groupdel ollama
 sudo rm -r /usr/share/ollama
-```
+```
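A practical note on the `.tgz` to `.tar.zst` switch above: the piped `tar x` relies on tar's automatic zstd detection, which needs a reasonably recent GNU tar (1.31+) with `zstd` available. On older systems you can decompress explicitly, for example:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
    | zstd -d | sudo tar x -C /usr
```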
docs/troubleshooting.md  (deleted)
@@ -1,3 +0,0 @@
-# Troubleshooting
-
-For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
 			t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
 		}

-		if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
+		if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
 			t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
 		}
 	case <-ctx.Done():
@@ -1464,6 +1464,11 @@ type CompletionRequest struct {

 	// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
 	TopLogprobs int
+
+	// Image generation fields
+	Width  int32 `json:"width,omitempty"`
+	Height int32 `json:"height,omitempty"`
+	Seed   int64 `json:"seed,omitempty"`
 }

 // DoneReason represents the reason why a completion response is done

@@ -1512,6 +1517,11 @@ type CompletionResponse struct {

 	// Logprobs contains log probability information if requested
 	Logprobs []Logprob `json:"logprobs,omitempty"`
+
+	// Image generation fields
+	Image []byte `json:"image,omitempty"` // Generated image
+	Step  int    `json:"step,omitempty"`  // Current generation step
+	Total int    `json:"total,omitempty"` // Total generation steps
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
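To make the intent of the new image-generation fields concrete, here is a hypothetical consumer of the streaming callback. The struct mirrors the fields added above; the helper itself is illustrative and not part of the diff:

```go
package main

import (
	"fmt"
	"os"
)

// completionChunk mirrors the new CompletionResponse fields for illustration.
type completionChunk struct {
	Image []byte // generated image bytes, set once generation completes
	Step  int    // current generation step
	Total int    // total generation steps
}

// handleChunk reports per-step progress and writes the image when present.
func handleChunk(resp completionChunk) error {
	if resp.Total > 0 {
		fmt.Printf("step %d/%d\n", resp.Step, resp.Total)
	}
	if len(resp.Image) > 0 {
		return os.WriteFile("out.png", resp.Image, 0o644)
	}
	return nil
}

func main() {
	_ = handleChunk(completionChunk{Step: 1, Total: 4})
}
```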
@@ -118,6 +118,9 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
 			return
 		}

+		// Set think to nil when being used with Anthropic API to connect to tools like claude code
+		c.Set("relax_thinking", true)
+
 		var b bytes.Buffer
 		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
@@ -582,3 +582,26 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
 		})
 	}
 }
+
+func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	var flagSet bool
+	router := gin.New()
+	router.Use(AnthropicMessagesMiddleware())
+	router.POST("/v1/messages", func(c *gin.Context) {
+		_, flagSet = c.Get("relax_thinking")
+		c.Status(http.StatusOK)
+	})
+
+	body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
+	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+
+	resp := httptest.NewRecorder()
+	router.ServeHTTP(resp, req)
+
+	if !flagSet {
+		t.Error("expected relax_thinking flag to be set in context")
+	}
+}
@@ -8,6 +8,7 @@ import (
 	"math/rand"
 	"net/http"
 	"strings"
+	"time"

 	"github.com/gin-gonic/gin"

@@ -441,6 +442,7 @@ type ResponsesWriter struct {
 	stream     bool
 	responseID string
 	itemID     string
+	request    openai.ResponsesRequest
 }

 func (w *ResponsesWriter) writeEvent(eventType string, data any) error {

@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {

 	// Non-streaming response
 	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
+	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
+	completedAt := time.Now().Unix()
+	response.CompletedAt = &completedAt
 	return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
 }

@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {

 		w := &ResponsesWriter{
 			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
+			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
 			model:      req.Model,
 			stream:     streamRequested,
 			responseID: responseID,
 			itemID:     itemID,
+			request:    req,
 		}

 		// Set headers based on streaming mode
@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {

 // decodeImageURL decodes a base64 data URI into raw image bytes.
 func decodeImageURL(url string) (api.ImageData, error) {
+	if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
+		return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
+	}
+
 	types := []string{"jpeg", "jpg", "png", "webp"}

 	// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
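The body of `decodeImageURL` is truncated above. As a general illustration of the technique (a hypothetical standalone helper, not the function's actual continuation), decoding a base64 data URI amounts to stripping the `data:<mime>;base64,` prefix and base64-decoding the rest:

```go
package main

import (
	"encoding/base64"
	"fmt"
	"strings"
)

// decodeDataURI strips an optional "data:<mime>;base64," prefix and decodes
// the remaining payload as standard base64. Illustrative only.
func decodeDataURI(s string) ([]byte, error) {
	if rest, ok := strings.CutPrefix(s, "data:"); ok {
		_, payload, found := strings.Cut(rest, ";base64,")
		if !found {
			return nil, fmt.Errorf("unsupported data URI")
		}
		s = payload
	}
	return base64.StdEncoding.DecodeString(s)
}

func main() {
	img, err := decodeDataURI("data:image/png;base64,iVBORw0KGgo=")
	fmt.Println(len(img), err) // 8 <nil> (the PNG signature)
}
```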
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"math/rand"
+	"time"

 	"github.com/ollama/ollama/api"
 )

@@ -265,9 +266,9 @@ type ResponsesText struct {
 type ResponsesTool struct {
 	Type        string         `json:"type"` // "function"
 	Name        string         `json:"name"`
-	Description string         `json:"description,omitempty"`
-	Strict      bool           `json:"strict,omitempty"`
-	Parameters  map[string]any `json:"parameters,omitempty"`
+	Description *string        `json:"description"` // nullable but required
+	Strict      *bool          `json:"strict"`      // nullable but required
+	Parameters  map[string]any `json:"parameters"`  // nullable but required
 }

 type ResponsesRequest struct {

@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
 		}
 	}

+	var description string
+	if t.Description != nil {
+		description = *t.Description
+	}
+
 	return api.Tool{
 		Type: t.Type,
 		Function: api.ToolFunction{
 			Name:        t.Name,
-			Description: t.Description,
+			Description: description,
 			Parameters:  params,
 		},
 	}, nil
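The "nullable but required" change is easy to see in isolation: with `omitempty` removed and pointer types used, unset fields marshal as explicit JSON `null` rather than being dropped. A self-contained check:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// tool mirrors the changed ResponsesTool fields for illustration.
type tool struct {
	Description *string        `json:"description"` // nullable but required
	Strict      *bool          `json:"strict"`      // nullable but required
	Parameters  map[string]any `json:"parameters"`  // nullable but required
}

func main() {
	b, _ := json.Marshal(tool{})
	// Unset pointer and map fields serialize as explicit nulls.
	fmt.Println(string(b)) // {"description":null,"strict":null,"parameters":null}
}
```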
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {

 // Response types for the Responses API

+// ResponsesTextField represents the text output configuration in the response.
+type ResponsesTextField struct {
+	Format ResponsesTextFormat `json:"format"`
+}
+
+// ResponsesReasoningOutput represents reasoning configuration in the response.
+type ResponsesReasoningOutput struct {
+	Effort  *string `json:"effort,omitempty"`
+	Summary *string `json:"summary,omitempty"`
+}
+
+// ResponsesError represents an error in the response.
+type ResponsesError struct {
+	Code    string `json:"code"`
+	Message string `json:"message"`
+}
+
+// ResponsesIncompleteDetails represents details about why a response was incomplete.
+type ResponsesIncompleteDetails struct {
+	Reason string `json:"reason"`
+}
+
 type ResponsesResponse struct {
-	ID        string                `json:"id"`
-	Object    string                `json:"object"`
-	CreatedAt int64                 `json:"created_at"`
-	Status    string                `json:"status"`
-	Model     string                `json:"model"`
-	Output    []ResponsesOutputItem `json:"output"`
-	Usage     *ResponsesUsage       `json:"usage,omitempty"`
 	// TODO(drifkin): add `temperature` and `top_p` to the response, but this
 	// requires additional plumbing to find the effective values since the
 	// defaults can come from the model or the request
+	ID                 string                      `json:"id"`
+	Object             string                      `json:"object"`
+	CreatedAt          int64                       `json:"created_at"`
+	CompletedAt        *int64                      `json:"completed_at"`
+	Status             string                      `json:"status"`
+	IncompleteDetails  *ResponsesIncompleteDetails `json:"incomplete_details"`
+	Model              string                      `json:"model"`
+	PreviousResponseID *string                     `json:"previous_response_id"`
+	Instructions       *string                     `json:"instructions"`
+	Output             []ResponsesOutputItem       `json:"output"`
+	Error              *ResponsesError             `json:"error"`
+	Tools              []ResponsesTool             `json:"tools"`
+	ToolChoice         any                         `json:"tool_choice"`
+	Truncation         string                      `json:"truncation"`
+	ParallelToolCalls  bool                        `json:"parallel_tool_calls"`
+	Text               ResponsesTextField          `json:"text"`
+	TopP               float64                     `json:"top_p"`
+	PresencePenalty    float64                     `json:"presence_penalty"`
+	FrequencyPenalty   float64                     `json:"frequency_penalty"`
+	TopLogprobs        int                         `json:"top_logprobs"`
+	Temperature        float64                     `json:"temperature"`
+	Reasoning          *ResponsesReasoningOutput   `json:"reasoning"`
+	Usage              *ResponsesUsage             `json:"usage"`
+	MaxOutputTokens    *int                        `json:"max_output_tokens"`
+	MaxToolCalls       *int                        `json:"max_tool_calls"`
+	Store              bool                        `json:"store"`
+	Background         bool                        `json:"background"`
+	ServiceTier        string                      `json:"service_tier"`
+	Metadata           map[string]any              `json:"metadata"`
+	SafetyIdentifier   *string                     `json:"safety_identifier"`
+	PromptCacheKey     *string                     `json:"prompt_cache_key"`
 }

 type ResponsesOutputItem struct {

@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
 }

 type ResponsesOutputContent struct {
-	Type string `json:"type"` // "output_text"
-	Text string `json:"text"`
+	Type        string `json:"type"` // "output_text"
+	Text        string `json:"text"`
+	Annotations []any  `json:"annotations"`
+	Logprobs    []any  `json:"logprobs"`
 }

+type ResponsesInputTokensDetails struct {
+	CachedTokens int `json:"cached_tokens"`
+}
+
+type ResponsesOutputTokensDetails struct {
+	ReasoningTokens int `json:"reasoning_tokens"`
+}
+
 type ResponsesUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
+	InputTokens         int                          `json:"input_tokens"`
+	OutputTokens        int                          `json:"output_tokens"`
+	TotalTokens         int                          `json:"total_tokens"`
+	InputTokensDetails  ResponsesInputTokensDetails  `json:"input_tokens_details"`
+	OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
 }

-// ToResponse converts an api.ChatResponse to a Responses API response
-func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
+// derefFloat64 returns the value of a float64 pointer, or a default if nil.
+func derefFloat64(p *float64, def float64) float64 {
+	if p != nil {
+		return *p
+	}
+	return def
+}
+
+// ToResponse converts an api.ChatResponse to a Responses API response.
+// The request is used to echo back request parameters in the response.
+func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
 	var output []ResponsesOutputItem

 	// Add reasoning item if thinking is present

@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
 		output = append(output, ResponsesOutputItem{
 			ID:        fmt.Sprintf("fc_%s_%d", responseID, i),
 			Type:      "function_call",
+			Status:    "completed",
 			CallID:    tc.ID,
 			Name:      tc.Function.Name,
 			Arguments: tc.Function.Arguments,

@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
 			Role: "assistant",
 			Content: []ResponsesOutputContent{
 				{
-					Type: "output_text",
-					Text: chatResponse.Message.Content,
+					Type:        "output_text",
+					Text:        chatResponse.Message.Content,
+					Annotations: []any{},
+					Logprobs:    []any{},
 				},
 			},
 		})
 	}

+	var instructions *string
+	if request.Instructions != "" {
+		instructions = &request.Instructions
+	}
+
+	// Build truncation with default
+	truncation := "disabled"
+	if request.Truncation != nil {
+		truncation = *request.Truncation
+	}
+
+	tools := request.Tools
+	if tools == nil {
+		tools = []ResponsesTool{}
+	}
+
+	text := ResponsesTextField{
+		Format: ResponsesTextFormat{Type: "text"},
+	}
+	if request.Text != nil && request.Text.Format != nil {
+		text.Format = *request.Text.Format
+	}
+
+	// Build reasoning output from request
+	var reasoning *ResponsesReasoningOutput
+	if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
+		reasoning = &ResponsesReasoningOutput{}
+		if request.Reasoning.Effort != "" {
+			reasoning.Effort = &request.Reasoning.Effort
+		}
+		if request.Reasoning.Summary != "" {
+			reasoning.Summary = &request.Reasoning.Summary
+		}
+	}
+
 	return ResponsesResponse{
-		ID:        responseID,
-		Object:    "response",
-		CreatedAt: chatResponse.CreatedAt.Unix(),
-		Status:    "completed",
-		Model:     model,
-		Output:    output,
+		ID:                 responseID,
+		Object:             "response",
+		CreatedAt:          chatResponse.CreatedAt.Unix(),
+		CompletedAt:        nil, // Set by middleware when writing final response
+		Status:             "completed",
+		IncompleteDetails:  nil, // Only populated if response incomplete
+		Model:              model,
+		PreviousResponseID: nil, // Not supported
+		Instructions:       instructions,
+		Output:             output,
+		Error:              nil, // Only populated on failure
+		Tools:              tools,
+		ToolChoice:         "auto", // Default value
+		Truncation:         truncation,
+		ParallelToolCalls:  true, // Default value
+		Text:               text,
+		TopP:               derefFloat64(request.TopP, 1.0),
+		PresencePenalty:    0, // Default value
+		FrequencyPenalty:   0, // Default value
+		TopLogprobs:        0, // Default value
+		Temperature:        derefFloat64(request.Temperature, 1.0),
+		Reasoning:          reasoning,
 		Usage: &ResponsesUsage{
 			InputTokens:  chatResponse.PromptEvalCount,
 			OutputTokens: chatResponse.EvalCount,
 			TotalTokens:  chatResponse.PromptEvalCount + chatResponse.EvalCount,
+			// TODO(drifkin): wire through the actual values
+			InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
+			// TODO(drifkin): wire through the actual values
+			OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
 		},
+		MaxOutputTokens:  request.MaxOutputTokens,
+		MaxToolCalls:     nil,   // Not supported
+		Store:            false, // We don't store responses
+		Background:       request.Background,
+		ServiceTier:      "default", // Default value
+		Metadata:         map[string]any{},
+		SafetyIdentifier: nil, // Not supported
+		PromptCacheKey:   nil, // Not supported
 	}
 }
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
 	responseID string
 	itemID     string
 	model      string
+	request    ResponsesRequest

 	// State tracking (mutated across Process calls)
 	firstWrite bool

@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]any) ResponsesStreamEvent {
 }

 // NewResponsesStreamConverter creates a new converter with the given configuration.
-func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
+func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
 	return &ResponsesStreamConverter{
 		responseID: responseID,
 		itemID:     itemID,
 		model:      model,
+		request:    request,
 		firstWrite: true,
 	}
 }
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStreamEvent {
 	return events
 }

+// buildResponseObject creates a full response object with all required fields for streaming events.
+func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
+	var instructions any = nil
+	if c.request.Instructions != "" {
+		instructions = c.request.Instructions
+	}
+
+	truncation := "disabled"
+	if c.request.Truncation != nil {
+		truncation = *c.request.Truncation
+	}
+
+	var tools []any
+	if c.request.Tools != nil {
+		for _, t := range c.request.Tools {
+			tools = append(tools, map[string]any{
+				"type":        t.Type,
+				"name":        t.Name,
+				"description": t.Description,
+				"strict":      t.Strict,
+				"parameters":  t.Parameters,
+			})
+		}
+	}
+	if tools == nil {
+		tools = []any{}
+	}
+
+	textFormat := map[string]any{"type": "text"}
+	if c.request.Text != nil && c.request.Text.Format != nil {
+		textFormat = map[string]any{
+			"type": c.request.Text.Format.Type,
+		}
+		if c.request.Text.Format.Name != "" {
+			textFormat["name"] = c.request.Text.Format.Name
+		}
+		if c.request.Text.Format.Schema != nil {
+			textFormat["schema"] = c.request.Text.Format.Schema
+		}
+		if c.request.Text.Format.Strict != nil {
+			textFormat["strict"] = *c.request.Text.Format.Strict
+		}
+	}
+
+	var reasoning any = nil
+	if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
+		r := map[string]any{}
+		if c.request.Reasoning.Effort != "" {
+			r["effort"] = c.request.Reasoning.Effort
+		} else {
+			r["effort"] = nil
+		}
+		if c.request.Reasoning.Summary != "" {
+			r["summary"] = c.request.Reasoning.Summary
+		} else {
+			r["summary"] = nil
+		}
+		reasoning = r
+	}
+
+	// Build top_p and temperature with defaults
+	topP := 1.0
+	if c.request.TopP != nil {
+		topP = *c.request.TopP
+	}
+	temperature := 1.0
+	if c.request.Temperature != nil {
+		temperature = *c.request.Temperature
+	}
+
+	return map[string]any{
+		"id":                   c.responseID,
+		"object":               "response",
+		"created_at":           time.Now().Unix(),
+		"completed_at":         nil,
+		"status":               status,
+		"incomplete_details":   nil,
+		"model":                c.model,
+		"previous_response_id": nil,
+		"instructions":         instructions,
+		"output":               output,
+		"error":                nil,
+		"tools":                tools,
+		"tool_choice":          "auto",
+		"truncation":           truncation,
+		"parallel_tool_calls":  true,
+		"text":                 map[string]any{"format": textFormat},
+		"top_p":                topP,
+		"presence_penalty":     0,
+		"frequency_penalty":    0,
+		"top_logprobs":         0,
+		"temperature":          temperature,
+		"reasoning":            reasoning,
+		"usage":                usage,
+		"max_output_tokens":    c.request.MaxOutputTokens,
+		"max_tool_calls":       nil,
+		"store":                false,
+		"background":           c.request.Background,
+		"service_tier":         "default",
+		"metadata":             map[string]any{},
+		"safety_identifier":    nil,
+		"prompt_cache_key":     nil,
+	}
+}
+
 func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
 	return c.newEvent("response.created", map[string]any{
-		"response": map[string]any{
-			"id":     c.responseID,
-			"object": "response",
-			"status": "in_progress",
-			"output": []any{},
-		},
+		"response": c.buildResponseObject("in_progress", []any{}, nil),
 	})
 }

 func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
 	return c.newEvent("response.in_progress", map[string]any{
-		"response": map[string]any{
-			"id":     c.responseID,
-			"object": "response",
-			"status": "in_progress",
-			"output": []any{},
-		},
+		"response": c.buildResponseObject("in_progress", []any{}, nil),
 	})
 }
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesStreamEvent {

 	// Emit delta
 	events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
-		"item_id":      c.reasoningItemID,
-		"output_index": c.outputIndex,
-		"delta":        thinking,
+		"item_id":       c.reasoningItemID,
+		"output_index":  c.outputIndex,
+		"summary_index": 0,
+		"delta":         thinking,
 	}))

 	// TODO(drifkin): consider adding

@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {

 	events := []ResponsesStreamEvent{
 		c.newEvent("response.reasoning_summary_text.done", map[string]any{
-			"item_id":      c.reasoningItemID,
-			"output_index": c.outputIndex,
-			"text":         c.accumulatedThinking,
+			"item_id":       c.reasoningItemID,
+			"output_index":  c.outputIndex,
+			"summary_index": 0,
+			"text":          c.accumulatedThinking,
 		}),
 		c.newEvent("response.output_item.done", map[string]any{
 			"output_index": c.outputIndex,
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": c.contentIndex,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"delta": content,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
return events
|
||||
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
})
|
||||
}
|
||||
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"text": c.accumulatedText,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
// response.content_part.done
|
||||
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// response.completed
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": c.buildFinalOutput(),
|
||||
"usage": map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
usage := map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
"output_tokens_details": map[string]any{
|
||||
"reasoning_tokens": 0,
|
||||
},
|
||||
}
|
||||
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
|
||||
response["completed_at"] = time.Now().Unix()
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": response,
|
||||
}))
|
||||
|
||||
return events
|
||||
|
||||
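The refactor above folds the inline `response` map into a shared `buildResponseObject` helper plus a richer `usage` map. As a rough illustration only — the helper's real receiver and signature are not shown in this diff, so everything beyond the field names visible above is an assumption — it plausibly looks like this:

```go
package main

import "fmt"

// Sketch of the shared envelope builder referenced above. The struct and
// signature are assumptions; only the field names come from the diff.
type converter struct {
	responseID string
}

func (c *converter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
	resp := map[string]any{
		"id":     c.responseID, // "resp_..." identifier
		"object": "response",
		"status": status, // e.g. "in_progress" or "completed"
		"output": output, // items from buildFinalOutput()
	}
	if usage != nil {
		resp["usage"] = usage
	}
	return resp
}

func main() {
	c := &converter{responseID: "resp_123"}
	usage := map[string]any{"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}
	fmt.Println(c.buildResponseObject("completed", []any{}, usage))
}
```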
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}

func TestResponsesStreamConverter_TextOnly(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

// First chunk with content
events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}

func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

events := converter.Process(api.ChatResponse{
Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}

func TestResponsesStreamConverter_Reasoning(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

// First chunk with thinking
events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42",
},
Done: true,
})
}, ResponsesRequest{})

// Should have 2 output items: reasoning + message
if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {

func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

// First chunk
converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {

func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

// Process some content
converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)

func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {

func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {

func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})

events := converter.Process(api.ChatResponse{
Message: api.Message{
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"strings"
)

type Prompt struct {
@@ -36,10 +37,11 @@ type Terminal struct {
}

type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
pastedLines []string
}

func New(prompt Prompt) (*Instance, error) {
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
buf.Remove()
if buf.IsEmpty() && len(i.pastedLines) > 0 {
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ:
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
i.Prompt.UseAlt = false

return output, nil
default:
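The readline changes above implement multi-line paste by stashing each intermediate line in `pastedLines` (on Ctrl+J) and joining the stash with the final line when Enter arrives. A reduced, self-contained sketch of that accumulate-then-join flow, with illustrative names rather than the package's real API:

```go
package main

import (
	"fmt"
	"strings"
)

// pasteBuffer mimics the accumulate-then-join behavior added to Readline:
// intermediate lines are stashed, and Enter flushes them as one string.
type pasteBuffer struct {
	pastedLines []string
}

// CtrlJ stores the current line, like the CharCtrlJ case above.
func (p *pasteBuffer) CtrlJ(line string) {
	p.pastedLines = append(p.pastedLines, line)
}

// Enter joins any stashed lines with the final one, like the CharEnter case.
func (p *pasteBuffer) Enter(line string) string {
	out := line
	if len(p.pastedLines) > 0 {
		out = strings.Join(p.pastedLines, "\n") + "\n" + out
		p.pastedLines = nil
	}
	return out
}

func main() {
	var p pasteBuffer
	p.CtrlJ("first pasted line")
	p.CtrlJ("second pasted line")
	fmt.Println(p.Enter("final line")) // three lines joined by \n
}
```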
@@ -73,7 +73,7 @@ _build_darwin() {
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,19 +82,19 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
chmod +x dist/darwin/ollama-mlx

if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done

# create a temporary zip for notarization
TEMP=$(mktemp -u).zip
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
rm -f "$TEMP"
fi

@@ -154,38 +154,40 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
# Copy MLX metallib (architecture-independent, just use arm64 version)
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
chmod a+x dist/Ollama.app/Contents/Resources/ollama

# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
fi

rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz

# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip

rm -f dist/Ollama.dmg

@@ -206,7 +208,7 @@ _build_macapp() {
rm -f dist/rw*.dmg

codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f stapler) staple dist/Ollama.dmg
else
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"
@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi

# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."

# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue

# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"

# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue

echo " Checking ${mlx_dir} against ${cuda_dir}..."

# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"

# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue

# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')

if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}

# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
fi

# buildx behavior changes for single vs. multiplatform
60
scripts/deduplicate_cuda_libs.sh
Executable file
@@ -0,0 +1,60 @@
#!/bin/sh
#
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
# This script finds identical .so* files in mlx_cuda_* directories that exist
# in corresponding cuda_* directories and replaces them with symlinks.
#

set -eu

if [ $# -eq 0 ]; then
echo "ERROR: No directory specified" >&2
echo "Usage: $0 <base_directory>" >&2
exit 1
fi

base_dir="$1"

if [ ! -d "${base_dir}" ]; then
echo "ERROR: Directory ${base_dir} does not exist" >&2
exit 1
fi

echo "Deduplicating CUDA libraries in ${base_dir}..."

# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue

# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"

# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue

echo " Checking ${mlx_dir} against ${cuda_dir}..."

# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"

# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue

# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')

if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done

echo "Deduplication complete"
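After the script runs, each duplicated library in `mlx_cuda_*` should be a relative symlink into the sibling `cuda_*` directory. A small hypothetical Go helper (not part of the repo) that walks a dist tree and reports which libraries were deduplicated:

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// Reports which mlx_cuda_* libraries are now symlinks into the sibling
// cuda_* directories. Purely a verification aid; the directory layout
// matches the script above, everything else is assumed.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "usage: verify <dist-dir>")
		os.Exit(1)
	}
	base := os.Args[1] // e.g. dist/linux-amd64
	pattern := filepath.Join(base, "lib", "ollama", "mlx_cuda_*", "*.so*")
	matches, err := filepath.Glob(pattern)
	if err != nil {
		panic(err)
	}
	for _, m := range matches {
		info, err := os.Lstat(m) // Lstat so symlinks are not followed
		if err != nil {
			continue
		}
		if info.Mode()&os.ModeSymlink != 0 {
			target, _ := os.Readlink(m)
			fmt.Printf("deduplicated: %s -> %s\n", m, target)
		} else {
			fmt.Printf("still a copy:  %s (%d bytes)\n", m, info.Size())
		}
	}
}
```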
@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil
}

func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}

// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
if redirectURL.Host != originalHost {
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
}

sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
113
server/auth_test.go
Normal file
@@ -0,0 +1,113 @@
package server

import (
"context"
"strings"
"testing"
"time"
)

func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}

for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)

isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}

func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}

for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}

func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}

u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}

if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}

q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}

// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}

func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}
@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}

if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}

return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR

// Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return nil, err
}
@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
}

func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
ms, err := Manifests(true)
if err != nil {
return err
}

// Build set of digests still in use by other manifests
inUse := make(map[string]struct{})
for _, other := range ms {
for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" {
inUse[layer.Digest] = struct{}{}
}
}
}

// Remove layers not used by any other manifest
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == "" {
continue
}
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
}
}

return nil
}
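The rewritten `RemoveLayers` is a mark-then-sweep: collect every digest referenced by the remaining manifests, then delete only this manifest's digests that nobody else uses. The same shape in miniature, as a standalone sketch with types simplified from the diff:

```go
package main

import "fmt"

// Miniature of the mark-then-sweep above: mark the digests still referenced
// by other manifests, then sweep only the target's unreferenced digests.
type manifest struct{ digests []string }

func removableDigests(target manifest, others []manifest) []string {
	inUse := make(map[string]struct{})
	for _, m := range others {
		for _, d := range m.digests {
			inUse[d] = struct{}{}
		}
	}
	var out []string
	for _, d := range target.digests {
		if _, used := inUse[d]; !used {
			out = append(out, d) // safe to remove from the blob store
		}
	}
	return out
}

func main() {
	target := manifest{digests: []string{"sha256:a", "sha256:b"}}
	others := []manifest{{digests: []string{"sha256:b"}}}
	fmt.Println(removableDigests(target, others)) // [sha256:a]
}
```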
114
server/routes.go
@@ -51,7 +51,6 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)

const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +163,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}

// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}

runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}

return runner.llama, nil
}

func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -214,12 +190,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}

name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1124,6 +1094,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}

// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}

if req.System != "" {
m.System = req.System
}
@@ -1206,6 +1185,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}

if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
return resp, nil
}

kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
@@ -1574,13 +1557,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Experimental OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", s.handleImageGeneration)

// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)

// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)

if rc != nil {
// wrap old with new
rs := &registry.Local{
@@ -1898,6 +1880,62 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}

func (s *Server) handleImageGeneration(c *gin.Context) {
var req struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

m, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}

runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

// Parse size (e.g., "1024x768") into width and height
width, height := int32(1024), int32(1024)
if req.Size != "" {
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
return
}
}

var image []byte
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: width,
Height: height,
}, func(resp llm.CompletionResponse) {
if len(resp.Image) > 0 {
image = resp.Image
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

c.JSON(http.StatusOK, gin.H{
"created": time.Now().Unix(),
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
})
}

func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()

@@ -2059,8 +2097,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
// Set think to nil when being used with Anthropic API to connect to tools like claude code
if _, ok := c.Get("relax_thinking"); ok {
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
req.Think = nil
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
}
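The `handleImageGeneration` handler above accepts a minimal OpenAI-style body (`model`, `prompt`, `size`) and returns base64 image data. A hedged client sketch against a local server — the endpoint shape is taken from the handler itself, while the host, port, and model name are placeholders:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Posts to the experimental /v1/images/generations endpoint shown above.
// The request and response fields mirror handleImageGeneration; the host,
// port, and model name are assumptions for illustration.
func main() {
	body, _ := json.Marshal(map[string]string{
		"model":  "z-image",       // placeholder model name
		"prompt": "a red bicycle", // free-form prompt
		"size":   "1024x768",      // parsed by the handler as WxH
	})
	resp, err := http.Post("http://localhost:11434/v1/images/generations",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Created int64 `json:"created"`
		Data    []struct {
			B64JSON string `json:"b64_json"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("created=%d images=%d\n", out.Created, len(out.Data))
}
```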
@@ -6,7 +6,6 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"

@@ -17,7 +16,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)

func TestMain(m *testing.M) {
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }

// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))

// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))

// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}

// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
// loaded in the scheduler can be evicted when idle.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}

// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()

s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn

// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}

langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}

s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()

// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()

// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)

// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return err
}
24
x/README.md
@@ -1,24 +0,0 @@
# Experimental Features

## MLX Backend

We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)

Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:

```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```

On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.

## Image Generation

Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:

```
go build -o imagegen ./x/imagegen/cmd/engine
```
@@ -41,6 +41,7 @@ var optionLabels = []string{
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
"web_fetch": "Web Fetch",
}

// ToolDisplayName returns the human-readable display name for a tool.
@@ -565,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}

// For web fetch, show URL and internet notice
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}

// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
@@ -1017,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
}
}

if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
// Truncate long URLs
if len(url) > 50 {
url = url[:47] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
}
}

return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}
308
x/cmd/run.go
@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"strings"
"syscall"
"time"
@@ -130,6 +131,7 @@ type RunOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
Verbose bool

// Agent fields (managed externally for session persistence)
Tools *tools.Registry
@@ -178,6 +180,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
var latest api.ChatResponse

role := "assistant"
messages := opts.Messages
@@ -187,6 +190,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
}

latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
@@ -483,6 +487,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Println()
}

if opts.Verbose {
latest.Summary()
}

return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}

@@ -634,12 +642,13 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
// If enableWebsearch is true, the web search tool is registered.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
@@ -660,6 +669,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if supportsTools {
toolRegistry = tools.DefaultRegistry()

// Register web search and web fetch tools if enabled via flag
if enableWebsearch {
toolRegistry.RegisterWebSearch()
toolRegistry.RegisterWebFetch()
}

if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
@@ -667,6 +682,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr)
}

if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
fmt.Fprintln(os.Stderr)
}

if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
@@ -677,6 +697,8 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op

var messages []api.Message
var sb strings.Builder
var format string
var system string

for {
line, err := scanner.Readline()
@@ -688,6 +710,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
continue
case err != nil:
@@ -707,6 +730,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load Load a different model")
fmt.Fprintln(os.Stderr, " /save Save session as a model")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -716,6 +743,280 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
thinkValue.Value = maybeLevel
}
think = &thinkValue
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
think = &api.ThinkValue{Value: false}
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message>")
continue
}

system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
}
continue
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) > 1 {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.ShowRequest{
Name: modelName,
Options: options,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model")
continue
}

switch args[1] {
case "info":
fmt.Fprintf(os.Stderr, " Model\n")
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
if resp.Details.Family != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
}
if resp.Details.ParameterSize != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
}
if resp.Details.QuantizationLevel != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
}
if len(resp.Capabilities) > 0 {
caps := make([]string, len(resp.Capabilities))
for i, c := range resp.Capabilities {
caps[i] = string(c)
}
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
}
fmt.Fprintln(os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
} else {
fmt.Println(resp.License)
}
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
if len(options) > 0 {
fmt.Println("\nUser defined parameters:")
for k, v := range options {
fmt.Printf(" %-30s %v\n", k, v)
}
}
case "system":
switch {
case system != "":
fmt.Println(system + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
fmt.Println(resp.Template)
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
}
continue
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /load <modelname>")
continue
}
newModelName := args[1]
fmt.Printf("Loading model '%s'\n", newModelName)

// Create progress spinner
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)

// Get client
client, err := api.ClientFromEnvironment()
if err != nil {
p.StopAndClear()
fmt.Println("error: couldn't connect to ollama server")
continue
}

// Check if model exists and get its info
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else {
fmt.Printf("error: %v\n", err)
}
continue
}

// For cloud models, no need to preload
if info.RemoteHost == "" {
// Preload the model by sending an empty generate request
req := &api.GenerateRequest{
Model: newModelName,
Think: think,
}
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
return nil
})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
} else {
fmt.Printf("error loading model: %v\n", err)
}
continue
}
}

p.StopAndClear()
modelName = newModelName
messages = []api.Message{}
approval.Reset()
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.CreateRequest{
Model: args[1],
From: modelName,
Parameters: options,
Messages: messages,
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Printf("error: %v\n", err)
continue
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
@@ -727,10 +1028,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)

verbose, _ := cmd.Flags().GetBool("verbose")
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Format: format,
Options: options,
Think: think,
HideThinking: hideThinking,
@@ -738,6 +1041,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
Verbose: verbose,
}

assistant, err := Chat(cmd.Context(), opts)
@@ -234,3 +234,17 @@ ollama create z-image
3. Copy config files (*.json) as config layers
4. Write manifest
```

## FP8 Quantization

Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.

### Usage

```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```

This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.
@@ -1,235 +0,0 @@
package api

import (
"encoding/base64"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/gin-gonic/gin"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)

// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}

// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}

// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}

// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}

// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}

// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}

// Parse size
width, height := parseSize(req.Size)

// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)

// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}

// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}

if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}

func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
} else {
|
||||
progress := parseProgress(resp.Content)
|
||||
if progress.Total > 0 {
|
||||
c.SSEvent("progress", progress)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.SSEvent("done", buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractPath(content string) string {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
return strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseProgress(content string) ImageProgressEvent {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
return ImageProgressEvent{Step: step, Total: total}
|
||||
}
|
||||
|
||||
func buildResponse(imagePath, format string) ImageGenerationResponse {
|
||||
resp := ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: make([]ImageData, 1),
|
||||
}
|
||||
|
||||
if imagePath == "" {
|
||||
return resp
|
||||
}
|
||||
|
||||
if format == "url" {
|
||||
resp.Data[0].URL = "file://" + imagePath
|
||||
} else {
|
||||
data, err := os.ReadFile(imagePath)
|
||||
if err == nil {
|
||||
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't block - channel is already being consumed
|
||||
_ = err
|
||||
}
|
||||
}()
|
||||
|
||||
streamFn(c, ch)
|
||||
}
|
||||
|
||||
// GenerateResponse matches api.GenerateResponse structure for streaming.
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
@@ -1,31 +0,0 @@
// Package api provides OpenAI-compatible image generation API types.
package api

// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt"`
	N              int    `json:"n,omitempty"`
	Size           string `json:"size,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	Stream         bool   `json:"stream,omitempty"`
}

// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
	Created int64       `json:"created"`
	Data    []ImageData `json:"data"`
}

// ImageData contains the generated image data.
type ImageData struct {
	URL           string `json:"url,omitempty"`
	B64JSON       string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
	Step  int `json:"step"`
	Total int `json:"total"`
}
197
x/imagegen/cache/teacache.go
vendored
Normal file
@@ -0,0 +1,197 @@
//go:build mlx

// Package cache provides caching mechanisms for diffusion model inference.
package cache

import (
	"github.com/ollama/ollama/x/imagegen/mlx"
)

// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
	// Cached transformer output from last computed step (non-CFG mode)
	cachedOutput *mlx.Array

	// Cached CFG outputs (pos and neg separately)
	cachedPosOutput *mlx.Array
	cachedNegOutput *mlx.Array

	// Previous timestep value for difference calculation
	prevTimestep float32

	// Accumulated difference for rescaling
	accumulatedDiff float32

	// Configuration
	threshold      float32 // Threshold for recomputation decision
	rescaleFactor  float32 // Model-specific rescaling factor
	skipEarlySteps int     // Number of early steps to never cache

	// Statistics
	cacheHits   int
	cacheMisses int
}

// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
	// Threshold for recomputation. Lower = more cache hits, potential quality loss.
	// Recommended: 0.05-0.15 for image models
	Threshold float32

	// Rescale factor to adjust timestep embedding differences.
	// Model-specific, typically 1.0-2.0
	RescaleFactor float32

	// SkipEarlySteps: number of early steps to always compute (never cache).
	// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
	SkipEarlySteps int
}

// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
	return &TeaCacheConfig{
		Threshold:     0.1,
		RescaleFactor: 1.0,
	}
}

// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
	if cfg == nil {
		cfg = DefaultTeaCacheConfig()
	}
	return &TeaCache{
		threshold:      cfg.Threshold,
		rescaleFactor:  cfg.RescaleFactor,
		skipEarlySteps: cfg.SkipEarlySteps,
	}
}

// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
//  1. First step always computes
//  2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
//  3. If accumulated difference > threshold, compute new output
//  4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
	// Always compute early steps (critical for structure)
	// Check both regular cache and CFG cache
	hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
	if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
		return true
	}

	// Compute absolute difference between current and previous timestep
	diff := timestep - tc.prevTimestep
	if diff < 0 {
		diff = -diff
	}

	// Apply rescaling factor
	scaledDiff := diff * tc.rescaleFactor

	// Accumulate difference (helps track drift over multiple cached steps)
	tc.accumulatedDiff += scaledDiff

	// Decision based on accumulated difference
	if tc.accumulatedDiff > tc.threshold {
		tc.accumulatedDiff = 0 // Reset accumulator
		return true
	}

	return false
}

// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
	// Free previous cached output
	if tc.cachedOutput != nil {
		tc.cachedOutput.Free()
	}

	// Store new cached values
	tc.cachedOutput = output
	tc.prevTimestep = timestep
	tc.cacheMisses++
}

// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
	// Free previous cached outputs
	if tc.cachedPosOutput != nil {
		tc.cachedPosOutput.Free()
	}
	if tc.cachedNegOutput != nil {
		tc.cachedNegOutput.Free()
	}

	// Store new cached values
	tc.cachedPosOutput = posOutput
	tc.cachedNegOutput = negOutput
	tc.prevTimestep = timestep
	tc.cacheMisses++
}

// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
	tc.cacheHits++
	return tc.cachedOutput
}

// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
	tc.cacheHits++
	return tc.cachedPosOutput, tc.cachedNegOutput
}

// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
	return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}

// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
	var arrays []*mlx.Array
	if tc.cachedOutput != nil {
		arrays = append(arrays, tc.cachedOutput)
	}
	if tc.cachedPosOutput != nil {
		arrays = append(arrays, tc.cachedPosOutput)
	}
	if tc.cachedNegOutput != nil {
		arrays = append(arrays, tc.cachedNegOutput)
	}
	return arrays
}

// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
	return tc.cacheHits, tc.cacheMisses
}

// Free releases all cached arrays.
func (tc *TeaCache) Free() {
	if tc.cachedOutput != nil {
		tc.cachedOutput.Free()
		tc.cachedOutput = nil
	}
	if tc.cachedPosOutput != nil {
		tc.cachedPosOutput.Free()
		tc.cachedPosOutput = nil
	}
	if tc.cachedNegOutput != nil {
		tc.cachedNegOutput.Free()
		tc.cachedNegOutput = nil
	}
}
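A minimal sketch of how a denoising loop might consult this cache. The `transformerForward` and `schedulerStep` callbacks stand in for the model's real forward pass and sampler update; they are illustrative assumptions, not part of this change:

```go
// Sketch only: transformerForward and schedulerStep are hypothetical stand-ins.
func denoiseWithTeaCache(latents *mlx.Array, timesteps []float32,
	transformerForward func(*mlx.Array, float32) *mlx.Array,
	schedulerStep func(*mlx.Array, *mlx.Array, float32) *mlx.Array,
) *mlx.Array {
	tc := cache.NewTeaCache(&cache.TeaCacheConfig{Threshold: 0.1, RescaleFactor: 1.0, SkipEarlySteps: 2})
	defer tc.Free()

	for step, t := range timesteps {
		var out *mlx.Array
		if tc.ShouldCompute(step, t) {
			out = transformerForward(latents, t) // full forward pass
			tc.UpdateCache(out, t)               // cache takes ownership of out
		} else {
			out = tc.GetCached() // reuse the last computed output
		}
		latents = schedulerStep(latents, out, t)
	}
	return latents
}
```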
@@ -7,7 +7,6 @@ package imagegen

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
@@ -39,75 +38,17 @@ func DefaultOptions() ImageGenOptions {
	return ImageGenOptions{
		Width:  1024,
		Height: 1024,
		Steps:  9,
		Steps:  0, // 0 means model default
		Seed:   0, // 0 means random
	}
}

// Show displays information about an image generation model.
func Show(modelName string, w io.Writer) error {
	manifest, err := LoadManifest(modelName)
	if err != nil {
		return fmt.Errorf("failed to load manifest: %w", err)
	}

	// Count total size
	var totalSize int64
	for _, layer := range manifest.Manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.tensor" {
			totalSize += layer.Size
		}
	}

	// Read model_index.json for architecture
	var architecture string
	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
		var index struct {
			Architecture string `json:"architecture"`
		}
		if json.Unmarshal(data, &index) == nil {
			architecture = index.Architecture
		}
	}

	// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
	paramCount := totalSize / 2
	paramStr := formatParamCount(paramCount)

	// Print Model info
	fmt.Fprintln(w, "  Model")
	if architecture != "" {
		fmt.Fprintf(w, "    %-20s %s\n", "architecture", architecture)
	}
	fmt.Fprintf(w, "    %-20s %s\n", "parameters", paramStr)
	fmt.Fprintf(w, "    %-20s %s\n", "quantization", "BF16")
	fmt.Fprintln(w)

	// Print Capabilities
	fmt.Fprintln(w, "  Capabilities")
	fmt.Fprintf(w, "    %s\n", "image")
	fmt.Fprintln(w)

	return nil
}

// formatParamCount formats parameter count as human-readable string.
func formatParamCount(count int64) string {
	if count >= 1_000_000_000 {
		return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
	}
	if count >= 1_000_000 {
		return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
	}
	return fmt.Sprintf("%d", count)
}

// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
	cmd.Flags().Int("width", 1024, "Image width")
	cmd.Flags().Int("height", 1024, "Image height")
	cmd.Flags().Int("steps", 9, "Denoising steps")
	cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
	cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
	cmd.Flags().String("negative", "", "Negative prompt")
	cmd.Flags().MarkHidden("width")
@@ -121,11 +62,6 @@ func RegisterFlags(cmd *cobra.Command) {
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
	// Verify it's a valid image gen model
	if ResolveModelName(name) == "" {
		return fmt.Errorf("unknown image generation model: %s", name)
	}

	// Get options from flags (with env var defaults)
	opts := DefaultOptions()
	if cmd != nil && cmd.Flags() != nil {
@@ -155,23 +91,18 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
}

// generateImageWithOptions generates an image with the given options.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
// Note: opts are currently unused as the native API doesn't support size parameters.
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	// Build request with image gen options encoded in Options fields
	// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
	req := &api.GenerateRequest{
		Model:  modelName,
		Prompt: prompt,
		Options: map[string]any{
			"num_ctx":     opts.Width,
			"num_gpu":     opts.Height,
			"num_predict": opts.Steps,
			"seed":        opts.Seed,
		},
		// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
	}
	if keepAlive != nil {
		req.KeepAlive = keepAlive
@@ -183,8 +114,7 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
	p.Add("", spinner)

	var stepBar *progress.StepBar
	var imagePath string

	var imageBase64 string
	err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
		content := resp.Response

@@ -203,11 +133,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
			return nil
		}

		// Handle final response with image path
		if resp.Done && strings.Contains(content, "Image saved to:") {
			if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
				imagePath = strings.TrimSpace(content[idx+16:])
			}
		// Handle final response with base64 image data
		if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
			imageBase64 = content[13:]
		}

		return nil
@@ -218,9 +146,27 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
		return err
	}

	if imagePath != "" {
		displayImageInTerminal(imagePath)
		fmt.Printf("Image saved to: %s\n", imagePath)
	if imageBase64 != "" {
		// Decode base64 and save to CWD
		imageData, err := base64.StdEncoding.DecodeString(imageBase64)
		if err != nil {
			return fmt.Errorf("failed to decode image: %w", err)
		}

		// Create filename from prompt
		safeName := sanitizeFilename(prompt)
		if len(safeName) > 50 {
			safeName = safeName[:50]
		}
		timestamp := time.Now().Format("20060102-150405")
		filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)

		if err := os.WriteFile(filename, imageData, 0o644); err != nil {
			return fmt.Errorf("failed to save image: %w", err)
		}

		displayImageInTerminal(filename)
		fmt.Printf("Image saved to: %s\n", filename)
	}

	return nil
@@ -306,7 +252,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
	p.Add("", spinner)

	var stepBar *progress.StepBar
	var imagePath string
	var imageBase64 string

	err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
		content := resp.Response
@@ -326,11 +272,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
			return nil
		}

		// Handle final response with image path
		if resp.Done && strings.Contains(content, "Image saved to:") {
			if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
				imagePath = strings.TrimSpace(content[idx+16:])
			}
		// Handle final response with base64 image data
		if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
			imageBase64 = content[13:]
		}

		return nil
@@ -342,25 +286,30 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
			continue
		}

		// Copy image to current directory with descriptive name
		if imagePath != "" {
		// Save image to current directory with descriptive name
		if imageBase64 != "" {
			// Decode base64 image data
			imageData, err := base64.StdEncoding.DecodeString(imageBase64)
			if err != nil {
				fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
				continue
			}

			// Create filename from prompt (sanitized)
			safeName := sanitizeFilename(line)
			if len(safeName) > 50 {
				safeName = safeName[:50]
			}
			timestamp := time.Now().Format("20060102-150405")
			newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
			filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)

			// Copy file to CWD
			if err := copyFile(imagePath, newName); err != nil {
				fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
				displayImageInTerminal(imagePath)
				fmt.Printf("Image saved to: %s\n", imagePath)
			} else {
				displayImageInTerminal(newName)
				fmt.Printf("Image saved to: %s\n", newName)
			if err := os.WriteFile(filename, imageData, 0o644); err != nil {
				fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
				continue
			}

			displayImageInTerminal(filename)
			fmt.Printf("Image saved to: %s\n", filename)
		}

		fmt.Println()
@@ -381,24 +330,6 @@ func sanitizeFilename(s string) string {
	return result.String()
}

// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
	sourceFile, err := os.Open(src)
	if err != nil {
		return err
	}
	defer sourceFile.Close()

	destFile, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer destFile.Close()

	_, err = io.Copy(destFile, sourceFile)
	return err
}

// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
	fmt.Fprintln(os.Stderr, "Commands:")
@@ -509,10 +440,7 @@ func displayImageInTerminal(imagePath string) bool {
	// Send in chunks for large images
	const chunkSize = 4096
	for i := 0; i < len(encoded); i += chunkSize {
		end := i + chunkSize
		if end > len(encoded) {
			end = len(encoded)
		}
		end := min(i+chunkSize, len(encoded))
		chunk := encoded[i:end]

		if i == 0 {
@@ -29,9 +29,10 @@ const MinOllamaVersion = "0.14.0"

// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to an 8-bit affine format during
// import (mxfp8 lacks matmul kernels in MLX; see quantize.go below).
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
	if !imagegen.IsTensorModelDir(modelDir) {
		return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
	}
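A quick usage sketch of the new signature. The `progress.NewProgress(os.Stderr)` constructor and `defer p.Stop()` are assumptions about the surrounding CLI code, not part of this change:

```go
// Sketch only, assuming the caller sets up a progress display on stderr.
p := progress.NewProgress(os.Stderr)
defer p.Stop()

// Import ./weights/Z-Image-Turbo as "z-image-fp8", quantizing during import.
if err := client.CreateModel("z-image-fp8", "./weights/Z-Image-Turbo", "fp8", p); err != nil {
	log.Fatal(err)
}
```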
@@ -58,18 +59,77 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {

	// Create tensor layer callback for individual tensors
	// name is path-style: "component/tensor_name"
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
	// When quantize is true, returns multiple layers (weight + scales)
	createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
		if doQuantize {
			// Check if quantization is supported
			if !QuantizeSupported() {
				return nil, fmt.Errorf("quantization requires MLX support")
			}

			// Quantize the tensor (affine mode returns weight, scales, qbiases)
			qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
			if err != nil {
				return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
			}

			// Create layer for quantized weight
			weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
			if err != nil {
				return nil, err
			}

			// Create layer for scales (use _scale suffix convention)
			scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
			if err != nil {
				return nil, err
			}

			layers := []imagegen.LayerInfo{
				{
					Digest:    weightLayer.Digest,
					Size:      weightLayer.Size,
					MediaType: weightLayer.MediaType,
					Name:      name, // Keep original name for weight
				},
				{
					Digest:    scalesLayer.Digest,
					Size:      scalesLayer.Size,
					MediaType: scalesLayer.MediaType,
					Name:      name + "_scale", // Add _scale suffix
				},
			}

			// Add qbiases layer if present (affine mode)
			if qbiasData != nil {
				qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
				if err != nil {
					return nil, err
				}
				layers = append(layers, imagegen.LayerInfo{
					Digest:    qbiasLayer.Digest,
					Size:      qbiasLayer.Size,
					MediaType: qbiasLayer.MediaType,
					Name:      name + "_qbias", // Add _qbias suffix
				})
			}

			return layers, nil
		}

		// Non-quantized path: just create a single layer
		layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
		if err != nil {
			return imagegen.LayerInfo{}, err
			return nil, err
		}
		layer.Name = name

		return imagegen.LayerInfo{
			Digest:    layer.Digest,
			Size:      layer.Size,
			MediaType: layer.MediaType,
			Name:      name,
		return []imagegen.LayerInfo{
			{
				Digest:    layer.Digest,
				Size:      layer.Size,
				MediaType: layer.MediaType,
				Name:      name,
			},
		}, nil
	}

@@ -119,7 +179,7 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
		p.Add("imagegen", spinner)
	}

	err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
	err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
	spinner.Stop()
	if err != nil {
		return err
120
x/imagegen/client/quantize.go
Normal file
@@ -0,0 +1,120 @@
//go:build mlx

package client

import (
	"fmt"
	"io"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	tmpDir := ensureTempDir()

	// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
	tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()
	defer os.Remove(tmpPath)

	if _, err := io.Copy(tmpFile, r); err != nil {
		tmpFile.Close()
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
	}
	tmpFile.Close()

	// Load the tensor using MLX's native loader
	st, err := mlx.LoadSafetensorsNative(tmpPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
	}
	defer st.Free()

	// Get the tensor (it's stored as "data" in our minimal safetensors format)
	arr := st.Get("data")
	if arr == nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
	}

	// Convert to BFloat16 if needed (quantize expects float type)
	if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
		arr = mlx.AsType(arr, mlx.DtypeBFloat16)
		mlx.Eval(arr)
	}

	// Quantize with affine mode: group_size=32, bits=8
	// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
	qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")

	// Eval and make contiguous for data access
	qweight = mlx.Contiguous(qweight)
	scales = mlx.Contiguous(scales)
	if qbiases != nil {
		qbiases = mlx.Contiguous(qbiases)
		mlx.Eval(qweight, scales, qbiases)
	} else {
		mlx.Eval(qweight, scales)
	}

	// Get shapes
	qweightShape = qweight.Shape()
	scalesShape = scales.Shape()

	// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
	qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
	defer os.Remove(qweightPath)
	if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
	}
	qweightData, err = os.ReadFile(qweightPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
	}

	// Save scales using MLX's native safetensors
	scalesPath := filepath.Join(tmpDir, "scales.safetensors")
	defer os.Remove(scalesPath)
	if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
	}
	scalesData, err = os.ReadFile(scalesPath)
	if err != nil {
		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
	}

	// Affine mode returns qbiases for zero-point offset
	if qbiases != nil {
		qbiasShape = qbiases.Shape()
		qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
		defer os.Remove(qbiasPath)
		if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
		}
		qbiasData, err = os.ReadFile(qbiasPath)
		if err != nil {
			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
		}
	}

	return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}

// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
	return true
}

// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
	tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
	os.MkdirAll(tmpDir, 0755)
	return tmpDir
}
18
x/imagegen/client/quantize_stub.go
Normal file
@@ -0,0 +1,18 @@
//go:build !mlx

package client

import (
	"fmt"
	"io"
)

// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
	return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}

// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
	return false
}
@@ -12,6 +12,7 @@ import (
	"path/filepath"
	"runtime/pprof"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/gemma3"
	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -48,7 +49,7 @@ func main() {
	// Image generation params
	width := flag.Int("width", 1024, "Image width")
	height := flag.Int("height", 1024, "Image height")
	steps := flag.Int("steps", 9, "Denoising steps")
	steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
	seed := flag.Int64("seed", 42, "Random seed")
	out := flag.String("output", "output.png", "Output path")

@@ -67,6 +68,9 @@ func main() {
	flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
	negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
	cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
	teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
	teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
	fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")

	flag.Parse()

@@ -99,13 +103,17 @@ func main() {
	}
	var img *mlx.Array
	img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
		Prompt:      *prompt,
		Width:       int32(*width),
		Height:      int32(*height),
		Steps:       *steps,
		Seed:        *seed,
		CapturePath: *gpuCapture,
		LayerCache:  *layerCache,
		Prompt:            *prompt,
		NegativePrompt:    *negativePrompt,
		CFGScale:          float32(*cfgScale),
		Width:             int32(*width),
		Height:            int32(*height),
		Steps:             *steps,
		Seed:              *seed,
		CapturePath:       *gpuCapture,
		TeaCache:          *teaCache,
		TeaCacheThreshold: float32(*teaCacheThreshold),
		FusedQKV:          *fusedQKV,
	})
	if err == nil {
		err = saveImageArray(img, *out)
@@ -40,10 +40,12 @@ type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo)

// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized
// to an 8-bit affine format (surfaced as "FP8" in model metadata).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
	var layers []LayerInfo
	var configLayer LayerInfo
	var totalParams int64 // Count parameters from original tensor shapes

	// Components to process - extract individual tensors from each
	components := []string{"text_encoder", "transformer", "vae"}
@@ -74,7 +76,11 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
		}

		tensorNames := extractor.ListTensors()
		fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
		quantizeMsg := ""
		if quantize == "fp8" && component != "vae" {
			quantizeMsg = ", quantizing to fp8"
		}
		fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))

		for _, tensorName := range tensorNames {
			td, err := extractor.GetTensor(tensorName)
@@ -83,16 +89,30 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
				return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
			}

			// Count parameters from original tensor shape
			if len(td.Shape) > 0 {
				numElements := int64(1)
				for _, dim := range td.Shape {
					numElements *= int64(dim)
				}
				totalParams += numElements
			}

			// Store as minimal safetensors format (88 bytes header overhead)
			// This enables native mmap loading via mlx_load_safetensors
			// Use path-style name: "component/tensor_name"
			fullName := component + "/" + tensorName
			layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)

			// Determine if this tensor should be quantized
			doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)

			// createTensorLayer returns multiple layers if quantizing (weight + scales)
			newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
			if err != nil {
				extractor.Close()
				return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
			}
			layers = append(layers, layer)
			layers = append(layers, newLayers...)
		}

		extractor.Close()
@@ -122,7 +142,7 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen

		var r io.Reader

		// For model_index.json, normalize to Ollama format
		// For model_index.json, normalize to Ollama format and add metadata
		if cfgPath == "model_index.json" {
			data, err := os.ReadFile(fullPath)
			if err != nil {
@@ -141,6 +161,16 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
			}
			delete(cfg, "_diffusers_version")

			// Add parameter count (counted from tensor shapes during import)
			cfg["parameter_count"] = totalParams

			// Add quantization info
			if quantize == "fp8" {
				cfg["quantization"] = "FP8"
			} else {
				cfg["quantization"] = "BF16"
			}

			data, err = json.MarshalIndent(cfg, "", "  ")
			if err != nil {
				return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
@@ -60,9 +60,12 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
	}

	// Transform to [H, W, C] for image conversion
	img := mlx.Squeeze(arr, 0)
	img = mlx.Transpose(img, 1, 2, 0)
	img = mlx.Contiguous(img)
	// Free intermediate arrays to avoid memory leak
	squeezed := mlx.Squeeze(arr, 0)
	transposed := mlx.Transpose(squeezed, 1, 2, 0)
	squeezed.Free()
	img := mlx.Contiguous(transposed)
	transposed.Free()
	mlx.Eval(img)

	imgShape := img.Shape()
@@ -175,3 +175,63 @@ func (m *ModelManifest) HasTensorLayers() bool {
	}
	return false
}

// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
	Architecture   string
	ParameterCount int64
	Quantization   string
}

// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
	manifest, err := LoadManifest(modelName)
	if err != nil {
		return nil, fmt.Errorf("failed to load manifest: %w", err)
	}

	info := &ModelInfo{}

	// Read model_index.json for architecture, parameter count, and quantization
	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
		var index struct {
			Architecture   string `json:"architecture"`
			ParameterCount int64  `json:"parameter_count"`
			Quantization   string `json:"quantization"`
		}
		if json.Unmarshal(data, &index) == nil {
			info.Architecture = index.Architecture
			info.ParameterCount = index.ParameterCount
			info.Quantization = index.Quantization
		}
	}

	// Fallback: detect quantization from tensor names if not in config
	if info.Quantization == "" {
		for _, layer := range manifest.Manifest.Layers {
			if strings.HasSuffix(layer.Name, ".weight_scale") {
				info.Quantization = "FP8"
				break
			}
		}
		if info.Quantization == "" {
			info.Quantization = "BF16"
		}
	}

	// Fallback: estimate parameter count if not in config
	if info.ParameterCount == 0 {
		var totalSize int64
		for _, layer := range manifest.Manifest.Layers {
			if layer.MediaType == "application/vnd.ollama.image.tensor" {
				if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
					totalSize += layer.Size
				}
			}
		}
		// Assume BF16 (2 bytes/param) as rough estimate
		info.ParameterCount = totalSize / 2
	}

	return info, nil
}
@@ -95,8 +95,3 @@ func EstimateVRAM(modelName string) uint64 {
	}
	return 21 * GB
}

// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
	return ResolveModelName(modelName) != ""
}
@@ -94,13 +94,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
	}
}

func TestHasTensorLayers(t *testing.T) {
	// Non-existent model should return false
	if HasTensorLayers("nonexistent-model") {
		t.Error("HasTensorLayers() should return false for non-existent model")
	}
}

func TestResolveModelName(t *testing.T) {
	// Non-existent model should return empty string
	result := ResolveModelName("nonexistent-model")
@@ -1,7 +1,5 @@
# MLX Memory Management

> This package will get consolidated with `x/ml/backend/mlx` in the future.

## Automatic Tracking

All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
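A short sketch of that tracking contract. The `produceArrays` helper is a hypothetical stand-in for any chain of mlx ops; `Keep`, `Eval`, `Kept`, and `Free` are the real calls used throughout this change:

```go
// Sketch only: produceArrays stands in for any array-producing mlx ops (hypothetical).
tmp, out := produceArrays()

mlx.Keep(out) // mark out to survive Eval() cleanup
mlx.Eval(out) // materializes out; tmp (non-kept) is freed here

_ = out.Kept() // true, per the Kept() accessor added below
out.Free()     // release explicitly when done
```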
@@ -607,6 +607,11 @@ func (a *Array) Valid() bool {
	return a != nil && a.c.ctx != nil
}

// Kept returns true if the array is marked to survive Eval() cleanup.
func (a *Array) Kept() bool {
	return a != nil && a.kept
}

func int32ToCInt(s []int32) *C.int {
	if len(s) == 0 {
		return nil
@@ -1480,6 +1485,44 @@ func (a *Array) ItemInt32() int32 {
	return int32(val)
}

// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
	cleanup()
	nbytes := a.Nbytes()
	if nbytes == 0 {
		return nil
	}

	// Get raw pointer based on dtype
	var ptr unsafe.Pointer
	switch a.Dtype() {
	case DtypeFloat32:
		ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
	case DtypeInt32:
		ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
	case DtypeUint32:
		ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
	case DtypeUint8:
		ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
	default:
		// For other types (bf16, f16, etc), convert to float32
		arr := AsType(a, DtypeFloat32)
		arr.Eval()
		ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
		nbytes = arr.Nbytes()
	}

	if ptr == nil {
		return nil
	}
	data := make([]byte, nbytes)
	copy(data, unsafe.Slice((*byte)(ptr), nbytes))
	return data
}
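A brief usage sketch of `Bytes()`, following the doc comment's advice to make the array contiguous first (`someArray` is a hypothetical stand-in for any live array):

```go
// Sketch only: someArray stands in for any live array (hypothetical).
arr := mlx.Contiguous(someArray) // Bytes needs contiguous memory
mlx.Eval(arr)
raw := arr.Bytes() // raw copy; len(raw) == arr.Nbytes() for supported dtypes
```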

// ============ Utility ============

// String returns a string representation
@@ -1658,6 +1701,34 @@ func (s *SafetensorsFile) Free() {
	C.mlx_map_string_to_string_free(s.metadata)
}

// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))

	// Create the map
	cArrays := C.mlx_map_string_to_array_new()
	defer C.mlx_map_string_to_array_free(cArrays)

	// Add each array to the map
	for name, arr := range arrays {
		cName := C.CString(name)
		C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
		C.free(unsafe.Pointer(cName))
	}

	// Create empty metadata (optional)
	cMeta := C.mlx_map_string_to_string_new()
	defer C.mlx_map_string_to_string_free(cMeta)

	// Save
	if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
		return fmt.Errorf("failed to save safetensors: %s", path)
	}
	return nil
}

// ============ NPY Loading ============

// LoadNpy loads a numpy array from an npy file
@@ -1986,7 +2057,8 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
// mode: "affine" (default), "mxfp4", or "mxfp8"
// Note: mxfp8 mode returns nil biases (only weights and scales)
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
@@ -1995,14 +2067,21 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
	res := C.mlx_vector_array_new()
	C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())

	// Result is a vector of 3 arrays: [weights, scales, biases]
	// Result is a vector of arrays: [weights, scales, biases?]
	// mxfp8 mode returns only 2 elements (no biases)
	vecSize := int(C.mlx_vector_array_size(res))
	var w0, w1, w2 C.mlx_array
	C.mlx_vector_array_get(&w0, res, 0)
	C.mlx_vector_array_get(&w1, res, 1)
	C.mlx_vector_array_get(&w2, res, 2)
	if vecSize >= 3 {
		C.mlx_vector_array_get(&w2, res, 2)
	}
	C.mlx_vector_array_free(res)

	return newArray(w0), newArray(w1), newArray(w2)
	if vecSize >= 3 {
		return newArray(w0), newArray(w1), newArray(w2)
	}
	return newArray(w0), newArray(w1), nil
}

// Dequantize reconstructs weights from quantized form.
@@ -9,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -172,7 +173,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 30
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
@@ -222,6 +223,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
@@ -264,10 +273,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// True CFG: run twice and combine with norm rescaling
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
L := batchedOutput.Shape()[1]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
@@ -305,6 +323,9 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
|
||||
@@ -241,6 +241,14 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Encode all input images to latents and concatenate
|
||||
fmt.Println("Encoding images to latents...")
|
||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||
@@ -291,11 +299,18 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Tile inputs: [1, L, D] -> [2, L, D]
|
||||
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
|
||||
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
|
||||
|
||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||
} else {
|
||||
@@ -317,6 +332,9 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
imageLatentsPacked.Free()
|
||||
|
||||
@@ -28,12 +28,12 @@ type Qwen3Config struct {
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OProj *nn.Linear `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
QProj nn.LinearLayer `weight:"q_proj"`
|
||||
KProj nn.LinearLayer `weight:"k_proj"`
|
||||
VProj nn.LinearLayer `weight:"v_proj"`
|
||||
OProj nn.LinearLayer `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
|
||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||
type Qwen3MLP struct {
|
||||
GateProj *nn.Linear `weight:"gate_proj"`
|
||||
UpProj *nn.Linear `weight:"up_proj"`
|
||||
DownProj *nn.Linear `weight:"down_proj"`
|
||||
GateProj nn.LinearLayer `weight:"gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
|
||||
@@ -36,8 +36,8 @@ type TransformerConfig struct {
|
||||
// TimestepEmbedder creates sinusoidal timestep embeddings
|
||||
// Output dimension is 256 (fixed), used for AdaLN modulation
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 *nn.Linear `weight:"mlp.0"`
|
||||
Linear2 *nn.Linear `weight:"mlp.2"`
|
||||
Linear1 nn.LinearLayer `weight:"mlp.0"`
|
||||
Linear2 nn.LinearLayer `weight:"mlp.2"`
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
|
||||
// XEmbedder embeds image patches to model dimension
|
||||
type XEmbedder struct {
|
||||
Linear *nn.Linear `weight:"2-1"`
|
||||
Linear nn.LinearLayer `weight:"2-1"`
|
||||
}
|
||||
|
||||
// Forward embeds patchified image latents
|
||||
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
||||
// CapEmbedder projects caption features to model dimension
|
||||
type CapEmbedder struct {
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear *nn.Linear `weight:"1"`
|
||||
Linear nn.LinearLayer `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
}
|
||||
|
||||
@@ -100,12 +100,13 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
|
||||
|
||||
// FeedForward implements SwiGLU FFN
|
||||
type FeedForward struct {
|
||||
W1 *nn.Linear `weight:"w1"` // gate projection
|
||||
W2 *nn.Linear `weight:"w2"` // down projection
|
||||
W3 *nn.Linear `weight:"w3"` // up projection
|
||||
W1 nn.LinearLayer `weight:"w1"` // gate projection
|
||||
W2 nn.LinearLayer `weight:"w2"` // down projection
|
||||
W3 nn.LinearLayer `weight:"w3"` // up projection
|
||||
OutDim int32 // computed from W2
|
||||
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -115,6 +116,7 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {

// Reshape for matmul
x = mlx.Reshape(x, B*L, D)

gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
@@ -126,17 +128,69 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {

// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
Dim int32 `weight:"-"`
Scale float32 `weight:"-"`
}

// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}

// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}

if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}

// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight

// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)

// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)

// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}

// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}

attn.ToQKV = fusedLinear
attn.Fused = true
}
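A minimal sketch (not part of this diff) of why the fusion is safe: concatenating the [out, in] weights along axis 0 makes the fused projection compute all three outputs side by side, so slicing recovers exactly what the separate projections would have produced. The helper name below is hypothetical; the mlx calls are the same ones used above.

```go
// fusedMatchesSeparate is a hypothetical sanity check: after FuseQKV, the
// sliced fused outputs should equal the separate Q/K/V projections, since
// x @ concat(Wq, Wk, Wv)^T == [x@Wq^T | x@Wk^T | x@Wv^T].
func fusedMatchesSeparate(attn *Attention, xFlat *mlx.Array, n int32) {
	qkv := attn.ToQKV.Forward(xFlat) // [n, 3*dim]
	q := mlx.Slice(qkv, []int32{0, 0}, []int32{n, attn.Dim})
	k := mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{n, 2 * attn.Dim})
	v := mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{n, 3 * attn.Dim})
	// q, k and v should match attn.ToQ.Forward(xFlat), attn.ToK.Forward(xFlat)
	// and attn.ToV.Forward(xFlat) exactly (same weights, one matmul).
	_, _, _ = q, k, v
}
```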
// Forward computes attention
@@ -146,11 +200,24 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
L := shape[1]
D := shape[2]

// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)

var q, k, v *mlx.Array
if attn.Fused && attn.ToQKV != nil {
// Fused QKV path: single matmul then split
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]

// Split into Q, K, V along last dimension
// Each has shape [B*L, dim]
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
} else {
// Separate Q, K, V projections
q = attn.ToQ.Forward(xFlat)
k = attn.ToK.Forward(xFlat)
v = attn.ToV.Forward(xFlat)
}

// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -227,7 +294,7 @@ type TransformerBlock struct {
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -281,8 +348,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml

// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}

@@ -350,12 +417,11 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)

// Load weights from tensor blobs with BF16 conversion
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
@@ -377,7 +443,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.CapEmbed.Norm.Eps = 1e-6

for _, block := range m.NoiseRefiners {
@@ -391,6 +457,20 @@ func (m *Transformer) initComputedFields() {
}
}

// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
}

// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
@@ -404,7 +484,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))

// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()

// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
@@ -423,6 +503,8 @@ type RoPECache struct {
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}

// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -456,6 +538,8 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}

@@ -104,6 +104,8 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
groupSize := C / gn.NumGroups

// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)

// Compute per-group mean and variance using flattened spatial dimensions
@@ -205,6 +207,10 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
}

// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
@@ -734,18 +740,26 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
h := vae.ConvIn.Forward(z)
mlx.Eval(h)

prev := h
h = vae.MidBlock.Forward(h)
prev.Free()

for _, upBlock := range vae.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}

prev := h
prev = h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()

prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()

// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
@@ -754,7 +768,6 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {

// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
prev.Free()
mlx.Eval(h)

return h

@@ -26,10 +26,12 @@ type GenerateConfig struct {
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)

// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)

// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)
}
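A caller-side sketch of these options (the values are illustrative defaults taken from the comments and code in this diff; how the config is passed to generation elsewhere is an assumption):

```go
cfg := &GenerateConfig{
	Width:             1024,
	Height:            1024,
	Steps:             9,    // Z-Image turbo default
	TeaCacheThreshold: 0.15, // lower = more aggressive cache reuse
	FusedQKV:          true, // fold Q/K/V into a single matmul
}
```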
// ProgressFunc is called during generation with step progress.
@@ -42,6 +44,7 @@ type Model struct {
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}

// Load loads the Z-Image model from ollama blob storage.
@@ -191,18 +194,22 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Turbo default
cfg.Steps = 9 // Z-Image turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
cfg.TeaCacheThreshold = 0.15
}

// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
m.Transformer.FuseAllQKV()
m.qkvFused = true
fmt.Println(" Fused QKV enabled")
}

useCFG := cfg.NegativePrompt != ""
@@ -260,12 +267,54 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(ropeCache.UnifiedCos)
}

// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}

// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
skipEarly := 0
if useCFG {
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
}
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
Threshold: cfg.TeaCacheThreshold,
RescaleFactor: 1.0,
SkipEarlySteps: skipEarly,
})
if useCFG {
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
} else {
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
}
}
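The TeaCache implementation itself lives in the cache package and is not shown in this diff; as a rough sketch of the idea (an assumption about its internals, not the actual code), the cached noise prediction is reused while the conditioning timestep signal has drifted less than the threshold since the last full forward pass:

```go
import "math"

// shouldReuse is a hypothetical sketch of a TeaCache-style decision rule:
// skip the transformer and reuse the cached output while the relative
// change in the timestep signal stays under the configured threshold.
func shouldReuse(prevT, currT, threshold float64) bool {
	rel := math.Abs(currT-prevT) / math.Max(math.Abs(prevT), 1e-6)
	return rel < threshold
}
```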
// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
teaCache.Free()
}
latents.Free()
}

// Denoising loop
@@ -277,6 +326,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if ctx != nil {
select {
case <-ctx.Done():
cleanup()
return nil, ctx.Err()
default:
}
@@ -289,50 +339,77 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
}

tCurr := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
var noisePred *mlx.Array

patches := PatchifyLatents(latents, tcfg.PatchSize)
// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)

var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if shouldCompute {
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)

var output *mlx.Array
if useCFG {
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
// CFG Batching: single forward pass with batch=2
// Tile patches: [1, L, D] -> [2, L, D]
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
// Tile timestep: [1] -> [2]
batchedTimestep := mlx.Tile(timestep, []int32{2})

// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)

// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
outputShape := batchedOutput.Shape()
L := outputShape[1]
D := outputShape[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})

// Convert to noise predictions (unpatchify and negate)
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
posPred = mlx.Neg(posPred)
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
negPred = mlx.Neg(negPred)

// Cache pos/neg separately for TeaCache
if teaCache != nil {
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}

// Apply CFG: noisePred = neg + scale * (pos - neg)
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else {
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
noisePred = mlx.Add(negPred, scaledDiff)
} else {
// Non-CFG forward pass
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
}
}
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)

noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
// Update TeaCache
if teaCache != nil {
teaCache.UpdateCache(noisePred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
}
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
// CFG mode: get cached pos/neg and compute CFG fresh
posPred, negPred := teaCache.GetCFGCached()
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
} else {
// Non-CFG mode: reuse cached noise prediction
noisePred = teaCache.GetCached()
fmt.Printf(" [TeaCache: reusing cached output]\n")
}

oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)

// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()

@@ -361,8 +438,14 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if stepCache != nil {
stepCache.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
hits, misses := teaCache.Stats()
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
hits, misses, float64(hits)/float64(hits+misses)*100)
teaCache.Free()
}

// VAE decode
@@ -10,6 +10,13 @@ type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}

// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32 // Returns the output dimension of the layer
}
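A small caller-side usage sketch (outside the nn package; the function name is hypothetical): because both implementations satisfy nn.LinearLayer, model code can hold one field type regardless of whether the checkpoint was quantized.

```go
// project works unchanged whether layer is a *nn.Linear or a
// *nn.QuantizedLinear - callers only see the interface.
func project(layer nn.LinearLayer, x *mlx.Array) *mlx.Array {
	out := layer.Forward(x)
	_ = layer.OutputDim() // e.g. for sizing downstream buffers
	return out
}
```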
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -49,6 +56,11 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
return mlx.Linear(x, w)
}

// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
return l.Weight.Shape()[0]
}

// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -84,6 +96,13 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
return out
}

// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
return ql.Weight.Shape()[0]
}

// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`

22
x/imagegen/quantize.go
Normal file
@@ -0,0 +1,22 @@

package imagegen

import (
"io"
"strings"
)

// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)

// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}
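Illustrative calls (the tensor names are made up for the example; only the rules in the function body above are authoritative):

```go
ShouldQuantize("layers.0.attention.to_q.weight", "transformer") // true: a linear weight
ShouldQuantize("decoder.conv_in.weight", "vae")                 // false: VAE is skipped
ShouldQuantize("layers.0.ffn_norm1.weight", "transformer")      // false: name contains "norm"
ShouldQuantize("layers.0.attention.to_q.bias", "transformer")   // false: not a ".weight" tensor
```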
@@ -13,7 +13,6 @@ import (
"net/http"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
@@ -34,7 +33,8 @@ type Request struct {

// Response is streamed back for each progress update
type Response struct {
Content string `json:"content"`
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
}
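A sketch of what the subprocess now streams for the final chunk (field names come from the struct above; the payload content is illustrative):

```go
resp := Response{
	Image: imageBase64, // base64-encoded PNG from EncodeImageBase64
	Done:  true,
}
data, _ := json.Marshal(resp)
// data: {"image":"iVBORw0KGgo...","done":true} - no content field needed
```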
@@ -136,16 +136,8 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
defer s.mu.Unlock()

// Apply defaults
if req.Width <= 0 {
req.Width = 1024
}
if req.Height <= 0 {
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
}
// Model applies its own defaults for width/height/steps
// Only seed needs to be set here if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
@@ -191,10 +183,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}

// Save image
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
if err := imagegen.SaveImage(img, outPath); err != nil {
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
@@ -204,11 +196,12 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()

// Send final response
// Send final response with image data
resp := Response{
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
Done: true,
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)

@@ -8,6 +8,7 @@ import (
"strings"

"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)

// WeightSource is an interface for loading weights.
@@ -102,6 +103,22 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
}

// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}

// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -176,3 +193,64 @@ func joinPath(prefix, suffix string) string {
}
return prefix + "." + suffix
}

// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
}

scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}

// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}

var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}

if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
}, nil
}

dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
return nn.NewLinear(dequantized, bias), nil
}

// Load as regular Linear
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}

// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}

return nn.NewLinear(weight, bias), nil
}
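Hypothetical call site (the tensor path is illustrative): the caller gets back whichever concrete layer matches the checkpoint.

```go
layer, err := LoadLinearLayer(weights, "layers.0.attention.to_q")
if err != nil {
	return err
}
// *nn.QuantizedLinear when "layers.0.attention.to_q.weight_scale" exists
// (and Metal is available), otherwise *nn.Linear.
y := layer.Forward(x)
```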
@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -14,7 +15,9 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"

@@ -23,6 +26,11 @@ import (
)

// Server wraps an image generation subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
@@ -35,21 +43,6 @@ type Server struct {
lastErrLock sync.Mutex
}

// completionRequest is sent to the subprocess
type completionRequest struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}

// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content"`
Done bool `json:"done"`
}

// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start
@@ -69,7 +62,7 @@ func NewServer(modelName string) (*Server, error) {
port = rand.Intn(65535-49152) + 49152
}

// Get the ollama executable path
// Get the ollama-mlx executable path (in same directory as current executable)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
@@ -77,11 +70,42 @@ func NewServer(modelName string) (*Server, error) {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")

// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()

// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}

// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}

pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}

s := &Server{
cmd: cmd,
port: port,
@@ -105,14 +129,13 @@ func NewServer(modelName string) (*Server, error) {
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
// Capture last error line for better error reporting
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()

slog.Info("starting image runner subprocess", "model", modelName, "port", port)
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
}
@@ -137,7 +160,6 @@ func (s *Server) ModelPath() string {
return s.modelName
}

// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -170,20 +192,16 @@ func (s *Server) waitUntilRunning() error {
for {
select {
case err := <-s.done:
// Include last stderr line for better error context
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout:
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
}
return errors.New("timeout waiting for image runner to start")
case <-ticker.C:
@@ -195,44 +213,39 @@ func (s *Server) waitUntilRunning() error {
}
}

// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}

// Completion generates an image from the prompt via the subprocess.
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Build request
creq := completionRequest{
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}

// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: 1024,
Height: 1024,
Steps: 9,
Seed: time.Now().UnixNano(),
Width: req.Width,
Height: req.Height,
Seed: seed,
}

if req.Options != nil {
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
creq.Width = int32(req.Options.NumCtx)
}
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
creq.Height = int32(req.Options.NumGPU)
}
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
creq.Steps = req.Options.NumPredict
}
if req.Options.Seed > 0 {
creq.Seed = int64(req.Options.Seed)
}
}

// Encode request body
body, err := json.Marshal(creq)
if err != nil {
return err
}

// Send request to subprocess
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
@@ -247,22 +260,40 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
return fmt.Errorf("request failed: %d", resp.StatusCode)
}

// Stream responses
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
// Parse subprocess response (has singular "image" field)
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
continue
}
fn(llm.CompletionResponse{
Content: cresp.Content,
Done: cresp.Done,
})

// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
Total: raw.Total,
}
if raw.Image != "" {
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
cresp.Image = data
}
}

fn(cresp)
if cresp.Done {
break
return nil
}
}

@@ -304,22 +335,18 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}

// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("embedding not supported for image generation models")
return nil, 0, errors.New("not supported")
}

// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("tokenize not supported for image generation models")
return nil, errors.New("not supported")
}

// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("detokenize not supported for image generation models")
return "", errors.New("not supported")
}

// Pid returns the subprocess PID.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -329,17 +356,9 @@ func (s *Server) Pid() int {
return -1
}

// GetPort returns the subprocess port.
func (s *Server) GetPort() int {
return s.port
}
func (s *Server) GetPort() int { return s.port }
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }

// GetDeviceInfos returns nil since we don't track GPU info.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}

// HasExited returns true if the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:
@@ -45,24 +45,33 @@ func download(ctx context.Context, opts DownloadOptions) error {
return nil
}

// Filter existing
var blobs []Blob
// Calculate total from all blobs (for accurate progress reporting on resume)
var total int64
for _, b := range opts.Blobs {
total += b.Size
}

// Filter out already-downloaded blobs and track completed bytes
var blobs []Blob
var alreadyCompleted int64
for _, b := range opts.Blobs {
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
if opts.Logger != nil {
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
}
alreadyCompleted += b.Size
continue
}
blobs = append(blobs, b)
total += b.Size
}
if len(blobs) == 0 {
return nil
}

token := opts.Token
progress := newProgressTracker(total, opts.Progress)
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront

d := &downloader{
client: cmp.Or(opts.Client, defaultClient),
baseURL: opts.BaseURL,
@@ -72,7 +81,7 @@ func download(ctx context.Context, opts DownloadOptions) error {
getToken: opts.GetToken,
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
progress: newProgressTracker(total, opts.Progress),
progress: progress,
speeds: &speedTracker{},
logger: opts.Logger,
}

@@ -110,8 +110,6 @@ var defaultClient = &http.Client{
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
},
Timeout: 5 * time.Minute,
// Don't follow redirects automatically - we handle them manually
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},

@@ -284,6 +284,83 @@ func TestDownloadSkipsExisting(t *testing.T) {
}
}

func TestDownloadResumeProgressTotal(t *testing.T) {
// Test that when resuming a download with some blobs already present:
// 1. Total reflects ALL blob sizes (not just remaining)
// 2. Completed starts at the size of already-downloaded blobs
serverDir := t.TempDir()
blob1, data1 := createTestBlob(t, serverDir, 1000)
blob2, data2 := createTestBlob(t, serverDir, 2000)
blob3, data3 := createTestBlob(t, serverDir, 3000)

// Pre-populate client with blob1 and blob2 (simulating partial download)
clientDir := t.TempDir()
for _, b := range []struct {
blob Blob
data []byte
}{{blob1, data1}, {blob2, data2}} {
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(path, b.data, 0o644); err != nil {
t.Fatal(err)
}
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
digest := filepath.Base(r.URL.Path)
path := filepath.Join(serverDir, digestToPath(digest))
data, err := os.ReadFile(path)
if err != nil {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(http.StatusOK)
w.Write(data)
}))
defer server.Close()

var firstCompleted, firstTotal int64
var gotFirstProgress bool
var mu sync.Mutex

err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{blob1, blob2, blob3},
BaseURL: server.URL,
DestDir: clientDir,
Concurrency: 1,
Progress: func(completed, total int64) {
mu.Lock()
defer mu.Unlock()
if !gotFirstProgress {
firstCompleted = completed
firstTotal = total
gotFirstProgress = true
}
},
})
if err != nil {
t.Fatalf("Download failed: %v", err)
}

// Total should be sum of ALL blobs, not just blob3
expectedTotal := blob1.Size + blob2.Size + blob3.Size
if firstTotal != expectedTotal {
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
}

// First progress call should show already-completed bytes from blob1+blob2
expectedCompleted := blob1.Size + blob2.Size
if firstCompleted < expectedCompleted {
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
}

// Verify blob3 was downloaded
verifyBlob(t, clientDir, blob3, data3)
}
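Caller-side sketch grounded in the test above: with the fix, the first Progress call on a resumed download already reports the bytes on disk against a total that covers every blob.

```go
err := Download(ctx, DownloadOptions{
	Blobs:   blobs,
	BaseURL: baseURL,
	DestDir: destDir,
	Progress: func(completed, total int64) {
		// On resume this starts near alreadyCompleted/total, not at 0%.
		fmt.Printf("\rdownloading: %5.1f%%", float64(completed)/float64(total)*100)
	},
})
```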
func TestDownloadDigestMismatch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return wrong data

@@ -1,77 +0,0 @@
package kvcache

import (
"errors"

"github.com/ollama/ollama/x/ml"
"github.com/ollama/ollama/x/model/input"
)

var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)

type Cache interface {
// ** used by model implementations **

// SetLayer sets the active layer of the cache
SetLayer(layer int)

// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)

// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)

// ** cache management **

// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

// Close closes the cache and frees resources associated with it
Close()

// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error

// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)

// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool

// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
@@ -1,797 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "errors"
|
||||
// "fmt"
|
||||
// "log/slog"
|
||||
// "math"
|
||||
// "slices"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||
|
||||
// // Causal cache stores K and V tensors according to their position in the
|
||||
// // sequence. Returns the history and a mask for attending to past tokens
|
||||
// //
|
||||
// // The tensors are of shape embed dim, kv heads, batch size
|
||||
// // The mask is of shape history size, batch size
|
||||
// type Causal struct {
|
||||
// DType ml.DType
|
||||
|
||||
// // swaWindowSize is the number of tokens that will be included in the mask
|
||||
// // during attention operations. swaMemorySize is the number of tokens that
|
||||
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
|
||||
// // for unlimited or if sliding window attention is not being used.
|
||||
// swaWindowSize int32
|
||||
// swaMemorySize int32
|
||||
|
||||
// chunkSize int32
|
||||
|
||||
// opts CausalOptions
|
||||
|
||||
// // maxBatch is the largest batch that we might receive
|
||||
// maxBatch int
|
||||
|
||||
// // config controls mostly backend-specific optimizations
|
||||
// config *ml.CacheConfig
|
||||
|
||||
// // ** current forward pass **
|
||||
|
||||
// // size of the current batch
|
||||
// curBatchSize int
|
||||
|
||||
// // locations for data storage for this batch
|
||||
// curLoc ml.Tensor
|
||||
|
||||
// // mask of the cache as used by this batch
|
||||
// curMask ml.Tensor
|
||||
|
||||
// // the active layer for Get and Put
|
||||
// curLayer int
|
||||
|
||||
// // locations in the cache that are needed for this batch
|
||||
// curCellRange cellRange
|
||||
|
||||
// // curSequences is the sequences corresponding to this pass's entries in the cache
|
||||
// curSequences []int
|
||||
|
||||
// // curPositions is the positions corresponding to this pass's entries in the cache
|
||||
// curPositions []int32
|
||||
|
||||
// // ** cache metadata **
|
||||
|
||||
// // for each possible location in the cache, stores the position and set of sequences
|
||||
// // that reference the data there
|
||||
// cells []cacheCell
|
||||
|
||||
// // maps from sequence to the range of locations where it is stored in the cache
|
||||
// cellRanges map[int]cellRange
|
||||
|
||||
// // ** cache data storage **
|
||||
|
||||
// shiftFn shiftFn
|
||||
// backend ml.Backend
|
||||
// ctxs map[int]ml.Context
|
||||
// keys, values map[int]ml.Tensor
|
||||
|
||||
// kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
// }
|
||||
|
||||
// type cacheCell struct {
|
||||
// pos int32
|
||||
// sequences []int
|
||||
// }
|
||||
|
||||
// type cellRange struct {
|
||||
// min int
|
||||
// max int
|
||||
// }
|
||||
|
||||
// func NewCausalCache(shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// swaMemorySize: memorySize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// chunkSize: chunkSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding == 0 {
|
||||
// c.config.CachePadding = 1
|
||||
// }
|
||||
|
||||
// if c.config.MaskBatchPadding == 0 {
|
||||
// c.config.MaskBatchPadding = 1
|
||||
// }
|
||||
|
||||
// // TODO what types do we handle here?
|
||||
// // if c.config.MaskDType == ml.DTypeOther {
|
||||
// // c.config.MaskDType = ml.DTypeFloat32
|
||||
// // }
|
||||
|
||||
// if c.swaWindowSize == 0 {
|
||||
// c.swaWindowSize = math.MaxInt32
|
||||
// }
|
||||
// if c.swaMemorySize == 0 {
|
||||
// c.swaMemorySize = c.swaWindowSize
|
||||
// }
|
||||
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||
// // causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||
// // because the extra token will live in the batch buffer and won't get overwritten if we
|
||||
// // only have a single sequence.
|
||||
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||
// }
|
||||
// if int(c.swaMemorySize) >= capacity {
|
||||
// c.swaMemorySize = math.MaxInt32
|
||||
// }
|
||||
|
||||
// if c.swaMemorySize < c.swaWindowSize {
|
||||
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
|
||||
// }
|
||||
|
||||
// var cacheSize int
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// cacheSize = maxSequences * capacity
|
||||
// } else {
|
||||
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
|
||||
// }
|
||||
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
// c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
// c.DType = dtype
|
||||
// c.cellRanges = make(map[int]cellRange)
|
||||
// c.backend = backend
|
||||
// c.maxBatch = maxBatch
|
||||
// }
|
||||

// func (c *Causal) SetConfig(config ml.CacheConfig) {
// 	if c.config != nil {
// 		panic("config cannot be changed after being previously set, either by the model or backend")
// 	}

// 	c.config = &config
// }

// func (c *Causal) Close() {
// 	slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
// 	for _, ctx := range c.ctxs {
// 		ctx.Close()
// 	}
// }

// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// 	slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
// 	// panic("XXX Causal.StartForward")
// 	c.curBatchSize = len(batch.Positions)
// 	c.curSequences = batch.Sequences
// 	c.curPositions = batch.Positions
// 	c.opts.Except = nil

// 	var locs []int32
// 	if !reserve {
// 		c.updateSlidingWindow()

// 		var err error
// 		locs, err = c.findLocs()
// 		if err != nil {
// 			return err
// 		}
// 		slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))

// 		for i, pos := range batch.Positions {
// 			seq := batch.Sequences[i]
// 			loc := int(locs[i])

// 			c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}

// 			seqRange, ok := c.cellRanges[seq]
// 			if !ok {
// 				seqRange = newRange()
// 			}

// 			seqRange.min = min(seqRange.min, loc)
// 			c.curCellRange.min = min(c.curCellRange.min, loc)

// 			seqRange.max = max(seqRange.max, loc)
// 			c.curCellRange.max = max(c.curCellRange.max, loc)

// 			c.cellRanges[seq] = seqRange
// 		}
// 	} else {
// 		// If we are reserving memory, don't update any of the cache metadata but set the size
// 		// to the worst case.
// 		locs = make([]int32, c.curBatchSize)
// 		for i := range locs {
// 			locs[i] = int32(i)
// 		}
// 		c.curCellRange.min = 0
// 		c.curCellRange.max = len(c.cells) - 1
// 	}

// 	// XXX Building up the locs for what's already processed (if any)
// 	dummyLocs := []int{}
// 	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// 	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

// 	for i := range c.curBatchSize {
// 		enabled := !slices.Contains(c.opts.Except, i)
// 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// 				(enabled && c.cells[j].pos > c.curPositions[i]) ||
// 				c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// 				c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// 				// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// 			} else {
// 				if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
// 					dummyLocs = append(dummyLocs, i)
// 				}
// 			}
// 		}
// 	}
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
|
||||

// 	slog.Info("XXX Causal.StartForward", "locs", locs)
// 	c.curLoc = ctx.Input().FromInts(locs, len(locs))
// 	c.curMask = c.buildMask(ctx)

// 	return nil
// }

// func newRange() cellRange {
// 	return cellRange{
// 		min: math.MaxInt,
// 		max: 0,
// 	}
// }

// // Returns a slice of locations where each token in the batch should be stored
// func (c *Causal) findLocs() ([]int32, error) {
// 	loc := make([]int32, 0, c.curBatchSize)

// 	for i := range c.cells {
// 		if len(c.cells[i].sequences) == 0 {
// 			loc = append(loc, int32(i))
// 			if len(loc) >= c.curBatchSize {
// 				return loc, nil
// 			}
// 		}
// 	}

// 	return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
// }
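
// Worked example (illustrative, not from the original source): if the cell
// occupancy is [used, free, used, free, free] and the current batch holds two
// tokens, findLocs returns [1, 3], the first two free cells in scan order.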

// func (c *Causal) updateSlidingWindow() {
// 	c.curCellRange = newRange()

// 	if c.swaMemorySize == math.MaxInt32 {
// 		for _, seq := range c.curSequences {
// 			if seqRange, ok := c.cellRanges[seq]; ok {
// 				c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
// 				c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
// 			}
// 		}

// 		return
// 	}

// 	type lowestPosition struct {
// 		pos      int32
// 		curBatch bool
// 	}

// 	// create a map of unique sequences to the lowest position in that sequence
// 	lowestPos := make(map[int]lowestPosition)
// 	for i := range c.curPositions {
// 		seq := c.curSequences[i]

// 		lowest, ok := lowestPos[seq]
// 		if !ok {
// 			lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
// 		} else if c.curPositions[i] < lowest.pos {
// 			lowest.pos = c.curPositions[i]
// 		}

// 		lowestPos[seq] = lowest
// 	}
// 	// for any sequences that are not part of this batch, clean up any tokens
// 	// that are no longer needed after the processing of the previous
// 	// batch
// 	for seq, seqRange := range c.cellRanges {
// 		if _, ok := lowestPos[seq]; !ok {
// 			var last int32
// 			for i := seqRange.min; i <= seqRange.max; i++ {
// 				if slices.Contains(c.cells[i].sequences, seq) {
// 					last = max(last, c.cells[i].pos)
// 				}
// 			}

// 			lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
// 		}
// 	}

// 	// delete any entries that are beyond the window of the oldest position in the sequence
// 	for seq, lowest := range lowestPos {
// 		oldRange, ok := c.cellRanges[seq]
// 		if !ok {
// 			continue
// 		}

// 		newRange := newRange()

// 		for i := oldRange.min; i <= oldRange.max; i++ {
// 			if slices.Contains(c.cells[i].sequences, seq) {
// 				if c.cells[i].pos < lowest.pos-c.swaMemorySize {
// 					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// 				} else {
// 					newRange.min = min(newRange.min, i)
// 					newRange.max = max(newRange.max, i)
// 				}
// 				if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
// 					c.curCellRange.min = min(c.curCellRange.min, i)
// 					c.curCellRange.max = max(c.curCellRange.max, i)
// 				}
// 			}
// 		}

// 		c.cellRanges[seq] = newRange
// 	}
// }
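
// Worked example (illustrative): with swaMemorySize = 4 and the lowest position
// of a sequence in the current batch at 10, updateSlidingWindow drops any cell
// of that sequence with pos < 10-4 = 6, while cells with pos >= 10-swaWindowSize
// remain part of the attention range for this batch.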

// func roundDown(length, pad int) int {
// 	return (length / pad) * pad
// }

// func roundUp(length, pad int) int {
// 	return ((length + pad - 1) / pad) * pad
// }
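
// For example (illustrative): with pad = 32, roundDown(70, 32) = 64 and
// roundUp(70, 32) = 96, so padding a cell range of [70, 70] with these helpers
// aligns it to [64, 95], matching the backend's CachePadding requirement.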

// // Builds a mask of history x batch indicating whether for each token in the batch the
// // token in the history should apply. This is based on both the sequence and causality (the
// // position of the history is not ahead of the token in the batch).
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
// 	// Align and pad the two dimensions as required by the backend
// 	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)

// 	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
// 	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

// 	length := c.curCellRange.max - c.curCellRange.min + 1

// 	mask := make([]float32, batchSize*length)

// 	for i := range c.curBatchSize {
// 		enabled := !slices.Contains(c.opts.Except, i)
// 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
// 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
// 				(enabled && c.cells[j].pos > c.curPositions[i]) ||
// 				c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
// 				c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
// 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
// 			}
// 		}
// 	}

// 	// Mask out any padding tokens we added. For padding that we added to the cache history, this
// 	// has already been masked out because the sequence doesn't match.
// 	for i := c.curBatchSize * length; i < len(mask); i++ {
// 		mask[i] = float32(math.Inf(-1))
// 	}

// 	maskTensor := ctx.Input().FromFloats(mask, batchSize, length)

// 	// if c.config.MaskDType != ml.DTypeFloat32 {
// 	// 	maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
// 	// }

// 	slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})

// 	return maskTensor
// }
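
// Illustrative example (not in the original file): with a single sequence
// holding positions {0, 1, 2} and a batch of one token at position 2, the mask
// row is [0, 0, 0]; with swaWindowSize = 1 it becomes [-Inf, 0, 0] because
// position 0 falls outside the window (0 < 2-1).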

// func (c *Causal) SetLayer(layer int) {
// 	c.curLayer = layer
// }

// type CausalOptions struct {
// 	// Except lists the batch indices for which causal mask generation is disabled
// 	Except []int
// }
// // SetCausal disables causal mask generation for a particular range of indices in
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
// 	if !slices.Equal(c.opts.Except, opts.Except) {
// 		c.opts = opts
// 		if ctx != nil {
// 			c.curMask = c.buildMask(ctx)
// 		}
// 	}
// }

// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// 	key := c.keys[c.curLayer]
// 	value := c.values[c.curLayer]

// 	kHeadDim := c.kHeadDims[c.curLayer]
// 	vHeadDim := c.vHeadDims[c.curLayer]
// 	numKVHeads := c.numKVHeads[c.curLayer]
// 	// rowSize := numKVHeads * c.curBatchSize
// 	// cachedSize := c.curMask.Dim(1)
// 	cachedSize := c.curLoc.Dim(0)
// 	// kCellSize := kHeadDim * numKVHeads
// 	// vCellSize := vHeadDim * numKVHeads

// 	slog.Info("XXX Causal.Get full cache", "key", key)
// 	slog.Info("XXX Causal.Get full cache", "value", value)
// 	slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
// 	slog.Info("XXX Causal.Get", "curMask", c.curMask)
// 	slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
// 	// panic("XXX")

// 	// fmt.Fprintln(os.Stderr, key.ToString())
// 	// panic("full cache value")

// 	// TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// 	key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
// 	// key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)

// 	// slog.Info("XXX Causal.Get after AsStrided", "key", key)
// 	// panic("XXX")

// 	// if c.config.PermutedV {
// 	// 	panic("permuted")
// 	// 	// TODO not converted
// 	// 	vHeadDim := value.Dim(1)
// 	// 	elemSize := value.Stride(2)

// 	// 	value = value.AsStrided(ctx,
// 	// 		[]int{numKVHeads, vHeadDim, cachedSize},
// 	// 		[]int{value.Stride(0), value.Stride(1)},
// 	// 		elemSize*c.curCellRange.min,
// 	// 	)
// 	// } else {
// 	// 	vHeadDim := c.vHeadDims[c.curLayer]
// 	// 	rowSize := value.Stride(2)
// 	// 	slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
// 	// 	panic("XXX")

// 	// TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
// 	value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
// 	// value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)

// 	// slog.Info("XXX Causal.Get after AsStrided", "value", value)
// 	// panic("XXX")

// 	// }

// 	// // TODO The mask changes from X,X to 1,X, and with the Row-order change
// 	// // the 1 becomes trailing and messes up later operations
// 	// // This isn't the right solution, but works around it...
// 	// if c.curMask.Dim(1) == 1 {
// 	// 	return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
// 	// }
// 	// fmt.Fprintln(os.Stderr, key.ToString())
// 	// fmt.Fprintln(os.Stderr, value.ToString())
// 	// panic("XXX")
// 	slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())

// 	return key, value, c.curMask
// }

// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
// 	kHeadDim := key.Dim(3)
// 	vHeadDim := value.Dim(3)
// 	numKVHeads := key.Dim(1)
// 	batchSize := key.Dim(2)
// 	kCellSize := kHeadDim * numKVHeads
// 	vCellSize := vHeadDim * numKVHeads

// 	// slog.Info("XXX Causal.Put", "key", key, "value", value)
// 	slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
// 	// panic("XXX")

// 	if c.curBatchSize != batchSize {
// 		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
// 	}

// 	// slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
// 	if _, ok := c.ctxs[c.curLayer]; !ok {
// 		slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
// 		c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// 	}

// 	if _, ok := c.keys[c.curLayer]; !ok {
// 		slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})

// 		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
// 		c.kHeadDims[c.curLayer] = kHeadDim
// 		c.vHeadDims[c.curLayer] = vHeadDim
// 		c.numKVHeads[c.curLayer] = numKVHeads
// 	}

// 	if _, ok := c.values[c.curLayer]; !ok {
// 		// if c.config.PermutedV {
// 		// 	c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
// 		// } else {
// 		c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
// 		// }
// 	}

// 	key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed

// 	// slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
// 	// panic("XXX")
// 	// curLoc := 0 // TODO c.curLoc is now a tensor
// 	// kSize := numKVHeads * kHeadDim
// 	// vSize := numKVHeads * vHeadDim
// 	// start := []int{int(curLoc), 0}
// 	// kStop := []int{int(curLoc + batchSize), int(kSize)}
// 	// vStop := []int{int(curLoc + batchSize), int(vSize)}
// 	// strides := []int{1, 1}

// 	// slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
// 	// slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)

// 	// slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)

// 	// ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
// 	ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
// 	// fmt.Fprintln(os.Stderr, keyCache.ToString())
// 	// panic("input value")

// 	// fmt.Fprintln(os.Stderr, t.ToString())
// 	// panic("XXX")

// 	// if c.config.PermutedV {
// 	// 	panic("permuted")
// 	// 	// TODO not adjusted
// 	// 	value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
// 	// 	value = value.Transpose(ctx, 2, 0, 1, 3)

// 	// 	valueCache := c.values[c.curLayer]
// 	// 	valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)

// 	// 	ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
// 	// } else {
// 	value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
// 	// slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
// 	// slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
// 	// slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)

// 	ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
// 	// }
// 	// fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
// 	// fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
// 	// panic("XXX")

// }
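
// Illustrative note on the Scatter calls above (assuming Scatter writes the
// batch rows at the indices held in c.curLoc along axis 0, as the call shape
// suggests): with curLoc = [5, 9] and a batch of two tokens, batch row 0 lands
// in row 5 and batch row 1 in row 9 of the (len(c.cells), kCellSize) key cache.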

// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
// 	seqRange := newRange()

// 	for i := range c.cells {
// 		// Remove the contents of dstSeq so that we only have the copied prefix; metadata will be reset at the end
// 		if slices.Contains(c.cells[i].sequences, dstSeq) {
// 			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
// 		}

// 		if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
// 			c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
// 			if i < seqRange.min {
// 				seqRange.min = i
// 			}
// 			if i > seqRange.max {
// 				seqRange.max = i
// 			}
// 		}
// 	}

// 	c.cellRanges[dstSeq] = seqRange
// }
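
// Worked example (illustrative): CopyPrefix(0, 1, 3) first strips sequence 1
// from every cell, then re-adds it to the cells of sequence 0 at positions 0, 1
// and 2, so both sequences share that prefix without copying any tensor data.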

// func (c *Causal) CanResume(seq int, pos int32) bool {
// 	if c.swaMemorySize == math.MaxInt32 {
// 		return true
// 	}

// 	seqRange, ok := c.cellRanges[seq]
// 	if !ok {
// 		return false
// 	}

// 	// for sliding window, check that the window of the new sequence is contained in
// 	// the window of what we are storing
// 	var first int32 = math.MaxInt32
// 	var last int32 = -1
// 	for i := seqRange.min; i <= seqRange.max; i++ {
// 		if slices.Contains(c.cells[i].sequences, seq) {
// 			first = min(first, c.cells[i].pos)
// 			last = max(last, c.cells[i].pos)
// 		}
// 	}

// 	if last == -1 {
// 		return false
// 	}

// 	posWindowStart := max(0, pos-c.swaWindowSize)
// 	return posWindowStart >= first && pos <= last+1
// }
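
// Worked example (illustrative): with swaWindowSize = 4 and cached positions
// first = 2 through last = 7, CanResume(seq, 6) checks max(0, 6-4) = 2 >= 2 and
// 6 <= 8, so resuming at position 6 is allowed; CanResume(seq, 3) fails because
// max(0, 3-4) = 0 < first.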

// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
// 	if c.shiftFn == nil {
// 		return ErrNotSupported
// 	}

// 	seqRange := c.cellRanges[seq]

// 	for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
// 		size := min(seqRange.max-start+1, c.maxBatch)
// 		offsets := make([]int32, size)

// 		var batchFirst, batchLast int

// 		batchFirst = -1
// 		for i := range offsets {
// 			cell := c.cells[start+i]

// 			if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
// 				offsets[i] = offset
// 				if batchFirst < 0 {
// 					batchFirst = i
// 				}
// 				batchLast = i
// 			}
// 		}

// 		if batchFirst < 0 {
// 			continue
// 		}

// 		offsets = offsets[batchFirst : batchLast+1]

// 		slog.Info("XXX Causal.shift creating new temporary context")
// 		ctx := c.backend.NewContext()
// 		kShift := ctx.Input().FromInts(offsets, len(offsets))

// 		for i, key := range c.keys {
// 			if key == nil {
// 				continue
// 			}

// 			kHeadDim := key.Dim(2)
// 			numKVHeads := key.Dim(1)
// 			rowSize := key.Stride(0)

// 			key = key.AsStrided(ctx,
// 				[]int{len(offsets), numKVHeads, kHeadDim},
// 				[]int{key.Stride(0), key.Stride(1)},
// 				rowSize*(start+batchFirst),
// 			)

// 			roped, err := c.shiftFn(ctx, i, key, kShift)
// 			if err != nil {
// 				ctx.Close()
// 				return err
// 			}

// 			ctx.Forward(roped.Copy(ctx, key))
// 		}

// 		ctx.Compute()
// 		ctx.Close()
// 	}

// 	return nil
// }

// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// 	// TODO(jessegross): We should check to see if removing the middle of the sequence will
// 	// cause the sliding window to encompass tokens that we no longer have. If so, then we
// 	// should return an error, which will trigger the runner to evaluate the full history and
// 	// rebuild the window. However, if we have multimodal inputs in our history, this reuse
// 	// results in use after free, so we don't do it for now.

// 	var offset int32
// 	if endIndex != math.MaxInt32 {
// 		offset = beginIndex - endIndex
// 	}

// 	seqRange := newRange()

// 	for i := range c.cells {
// 		if slices.Contains(c.cells[i].sequences, seq) {
// 			if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
// 				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
// 			} else {
// 				if c.cells[i].pos >= endIndex {
// 					if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// 						return errors.New("shifting cells shared by multiple sequences not supported")
// 					}

// 					c.cells[i].pos += offset
// 				}
// 				if i < seqRange.min {
// 					seqRange.min = i
// 				}
// 				if i > seqRange.max {
// 					seqRange.max = i
// 				}
// 			}
// 		}
// 	}

// 	if seqRange == newRange() {
// 		delete(c.cellRanges, seq)
// 		return nil
// 	}

// 	c.cellRanges[seq] = seqRange

// 	if endIndex != math.MaxInt32 {
// 		err := c.shift(seq, endIndex+offset, offset)
// 		if err != nil {
// 			return err
// 		}
// 	}

// 	return nil
// }
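
// Worked example (illustrative): Remove(seq, 2, 5) computes offset = 2-5 = -3,
// drops the cells at positions [2, 5), slides a token at position 7 down to
// position 4, and then calls shift(seq, 2, -3) so the shift function (for
// example a RoPE re-rotation) can fix up the stored keys.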
@@ -1,973 +0,0 @@
package kvcache

// import (
// 	"fmt"
// 	"math"
// 	"slices"
// 	"testing"

// 	"github.com/ollama/ollama/ml"
// 	"github.com/ollama/ollama/model/input"
// )

// type testCase struct {
// 	name          string
// 	in            []float32
// 	inShape       []int
// 	seqs          []int
// 	pos           []int32
// 	expected      []float32
// 	expectedShape []int
// 	expectedMask  []float32
// }

// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
// 	t.Helper()
// 	for _, permuted := range []bool{false, true} {
// 		t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
// 			fn(t, &testBackend{permutedV: permuted})
// 		})
// 	}
// }

// func TestStore(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewCausalCache(nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		tests := []testCase{
// 			{
// 				name:          "FirstBatch",
// 				in:            []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// 				inShape:       []int{2, 3, 4},
// 				seqs:          []int{0, 0, 0, 0},
// 				pos:           []int32{0, 1, 2, 3},
// 				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// 				expectedShape: []int{2, 3, 4},
// 				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// 			},
// 			{
// 				name:          "SecondBatch",
// 				in:            []float32{115, 215, 125, 225, 135, 235},
// 				inShape:       []int{2, 3, 1},
// 				seqs:          []int{0},
// 				pos:           []int32{4},
// 				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
// 				expectedShape: []int{2, 3, 5},
// 				expectedMask:  []float32{0, 0, 0, 0, 0},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)
// 	})
// }

// func TestSWA(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewSWACache(1, nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		x := float32(math.Inf(-1))

// 		tests := []testCase{
// 			{
// 				name:          "FirstBatch",
// 				in:            []float32{1, 2, 3, 4},
// 				inShape:       []int{1, 1, 4},
// 				seqs:          []int{0, 0, 0, 0},
// 				pos:           []int32{0, 1, 2, 3},
// 				expected:      []float32{1, 2, 3, 4},
// 				expectedShape: []int{1, 1, 4},
// 				expectedMask: []float32{
// 					0, x, x, x,
// 					0, 0, x, x,
// 					x, 0, 0, x,
// 					x, x, 0, 0,
// 				},
// 			},
// 			{
// 				name:          "SecondBatch",
// 				in:            []float32{5, 6},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 0},
// 				pos:           []int32{4, 5},
// 				expected:      []float32{5, 6, 3, 4},
// 				expectedShape: []int{1, 1, 4},
// 				expectedMask: []float32{
// 					0, x, x, 0,
// 					0, 0, x, x,
// 				},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)
// 	})
// }

// func TestSWASeparateBatches(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewSWACache(1, nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 2, 16, 2)

// 		x := float32(math.Inf(-1))

// 		tests := []testCase{
// 			{
// 				name:          "First seq 0",
// 				in:            []float32{1, 2},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 0},
// 				pos:           []int32{0, 1},
// 				expected:      []float32{1, 2},
// 				expectedShape: []int{1, 1, 2},
// 				expectedMask: []float32{
// 					0, x,
// 					0, 0,
// 				},
// 			},
// 			{
// 				name:          "Second seq 0",
// 				in:            []float32{3, 4},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 0},
// 				pos:           []int32{2, 3},
// 				expected:      []float32{2, 3, 4},
// 				expectedShape: []int{1, 1, 3},
// 				expectedMask: []float32{
// 					0, 0, x,
// 					x, 0, 0,
// 				},
// 			},
// 			{
// 				name:          "First seq 1",
// 				in:            []float32{5, 6},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{1, 1},
// 				pos:           []int32{0, 1},
// 				expected:      []float32{5, 6},
// 				expectedShape: []int{1, 1, 2},
// 				expectedMask: []float32{
// 					0, x,
// 					0, 0,
// 				},
// 			},
// 			{
// 				name:          "Second seq 1",
// 				in:            []float32{7, 8},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{1, 1},
// 				pos:           []int32{2, 3},
// 				expected:      []float32{6, 3, 4, 7, 8},
// 				expectedShape: []int{1, 1, 5},
// 				expectedMask: []float32{
// 					0, x, x, 0, x,
// 					x, x, x, 0, 0,
// 				},
// 			},
// 			{
// 				name:          "Third seq 0",
// 				in:            []float32{9, 10},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 0},
// 				pos:           []int32{4, 5},
// 				expected:      []float32{9, 10, 3, 4},
// 				expectedShape: []int{1, 1, 4},
// 				expectedMask: []float32{
// 					0, x, x, 0,
// 					0, 0, x, x,
// 				},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)
// 	})
// }

// func TestSWAMem(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewSWAMemCache(1, 3, nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		x := float32(math.Inf(-1))

// 		tests := []testCase{
// 			{
// 				name:          "FirstBatch",
// 				in:            []float32{1, 2, 3, 4},
// 				inShape:       []int{1, 1, 4},
// 				seqs:          []int{0, 0, 0, 0},
// 				pos:           []int32{0, 1, 2, 3},
// 				expected:      []float32{1, 2, 3, 4},
// 				expectedShape: []int{1, 1, 4},
// 				expectedMask: []float32{
// 					0, x, x, x,
// 					0, 0, x, x,
// 					x, 0, 0, x,
// 					x, x, 0, 0,
// 				},
// 			},
// 			{
// 				name:          "SecondBatch",
// 				in:            []float32{5, 6},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 0},
// 				pos:           []int32{4, 5},
// 				expected:      []float32{5, 2, 3, 4, 6},
// 				expectedShape: []int{1, 1, 5},
// 				expectedMask: []float32{
// 					0, x, x, 0, x,
// 					0, x, x, x, 0,
// 				},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)
// 	})
// }

// func TestChunkedAttention(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewChunkedAttentionCache(2, nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		x := float32(math.Inf(-1))

// 		testCache(
// 			t, backend, cache,
// 			[]testCase{
// 				{
// 					name:          "FirstBatch",
// 					in:            []float32{1, 2, 3, 4},
// 					inShape:       []int{1, 1, 4},
// 					seqs:          []int{0, 0, 0, 0},
// 					pos:           []int32{0, 1, 2, 3},
// 					expected:      []float32{1, 2, 3, 4},
// 					expectedShape: []int{1, 1, 4},
// 					expectedMask: []float32{
// 						0, x, x, x,
// 						0, 0, x, x,
// 						x, x, 0, x,
// 						x, x, 0, 0,
// 					},
// 				},
// 				{
// 					name:          "SecondBatch",
// 					in:            []float32{5, 6, 7},
// 					inShape:       []int{1, 1, 3},
// 					seqs:          []int{0, 0, 0},
// 					pos:           []int32{4, 5, 6},
// 					expected:      []float32{1, 2, 3, 4, 5, 6, 7},
// 					expectedShape: []int{1, 1, 7},
// 					expectedMask: []float32{
// 						x, x, x, x, 0, x, x,
// 						x, x, x, x, 0, 0, x,
// 						x, x, x, x, x, x, 0,
// 					},
// 				},
// 				{
// 					name:          "ThirdBatch",
// 					in:            []float32{8, 9},
// 					inShape:       []int{1, 1, 2},
// 					seqs:          []int{0, 0},
// 					pos:           []int32{7, 8},
// 					expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
// 					expectedShape: []int{1, 1, 9},
// 					expectedMask: []float32{
// 						x, x, x, x, x, x, 0, 0, x,
// 						x, x, x, x, x, x, x, x, 0,
// 					},
// 				},
// 			},
// 		)
// 	})
// }

// func TestSequences(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewCausalCache(nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		tests := []testCase{
// 			{
// 				name:          "FirstBatch",
// 				in:            []float32{1, 2, 3, 4},
// 				inShape:       []int{1, 1, 4},
// 				seqs:          []int{0, 0, 1, 1},
// 				pos:           []int32{0, 1, 0, 1},
// 				expected:      []float32{1, 2, 3, 4},
// 				expectedShape: []int{1, 1, 4},
// 				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// 			},
// 			{
// 				name:          "SecondBatch",
// 				in:            []float32{5, 6},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 1},
// 				pos:           []int32{2, 2},
// 				expected:      []float32{1, 2, 3, 4, 5, 6},
// 				expectedShape: []int{1, 1, 6},
// 				expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)
// 	})
// }

// func TestRemove(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// 			return key.Add(ctx, shift), nil
// 		})
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		x := float32(math.Inf(-1))

// 		tests := []testCase{
// 			{
// 				name:          "FirstBatch",
// 				in:            []float32{1, 2, 3, 4},
// 				inShape:       []int{1, 1, 4},
// 				seqs:          []int{0, 0, 1, 1},
// 				pos:           []int32{0, 1, 0, 1},
// 				expected:      []float32{1, 2, 3, 4},
// 				expectedShape: []int{1, 1, 4},
// 				expectedMask: []float32{
// 					0, x, x, x,
// 					0, 0, x, x,
// 					x, x, 0, x,
// 					x, x, 0, 0,
// 				},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)

// 		err := cache.Remove(0, 1, math.MaxInt32)
// 		if err != nil {
// 			panic(err)
// 		}

// 		tests = []testCase{
// 			{
// 				name:          "RemoveEnd",
// 				in:            []float32{5, 6},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 1},
// 				pos:           []int32{1, 2},
// 				expected:      []float32{1, 5, 3, 4, 6},
// 				expectedShape: []int{1, 1, 5},
// 				expectedMask: []float32{
// 					0, 0, x, x, x,
// 					x, x, 0, 0, 0,
// 				},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)

// 		err = cache.Remove(0, 0, 1)
// 		if err != nil {
// 			panic(err)
// 		}

// 		tests = []testCase{
// 			{
// 				name:          "RemoveMiddle",
// 				in:            []float32{7, 8},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{0, 0},
// 				pos:           []int32{1, 2},
// 				expected:      []float32{7, 4, 3, 4, 6, 8},
// 				expectedShape: []int{1, 1, 6},
// 				expectedMask: []float32{
// 					0, 0, x, x, x, x,
// 					0, 0, x, x, x, 0,
// 				},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)
// 	})
// }

// func TestCopy(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		tests := []testCase{
// 			{
// 				name:          "FirstBatch",
// 				in:            []float32{1, 2, 3, 4},
// 				inShape:       []int{1, 1, 4},
// 				seqs:          []int{0, 0, 0, 0},
// 				pos:           []int32{0, 1, 2, 3},
// 				expected:      []float32{1, 2, 3, 4},
// 				expectedShape: []int{1, 1, 4},
// 				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)

// 		cache.CopyPrefix(0, 1, 2)

// 		tests = []testCase{
// 			{
// 				name:          "Copy",
// 				in:            []float32{5, 6},
// 				inShape:       []int{1, 1, 2},
// 				seqs:          []int{1, 1},
// 				pos:           []int32{3, 4},
// 				expected:      []float32{1, 2, 3, 4, 5, 6},
// 				expectedShape: []int{1, 1, 6},
// 				expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// 			},
// 		}

// 		testCache(t, backend, cache, tests)
// 	})
// }

// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
// 	for _, test := range tests {
// 		t.Run(test.name, func(t *testing.T) {
// 			context := backend.NewContext()
// 			defer context.Close()

// 			err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
// 			if err != nil {
// 				panic(err)
// 			}

// 			cache.SetLayer(0)
// 			tensor := context.FromFloats(test.in, test.inShape...)
// 			cache.Put(context, tensor, tensor)

// 			out, _, mask := cache.Get(context)

// 			context.Forward(out, mask).Compute(out, mask)

// 			if !slices.Equal(out.Floats(), test.expected) {
// 				t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
// 			}

// 			if !slices.Equal(out.Shape(), test.expectedShape) {
// 				t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
// 			}

// 			if !slices.Equal(mask.Floats(), test.expectedMask) {
// 				t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
// 			}
// 		})
// 	}
// }

// func TestCanResume(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		windowSize := int32(4)
// 		cache := NewSWACache(windowSize, nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		context := backend.NewContext()
// 		defer context.Close()

// 		err := cache.StartForward(context, input.Batch{
// 			Positions: []int32{0, 1, 2, 3, 4},
// 			Sequences: []int{0, 0, 0, 0, 0},
// 		}, false)
// 		if err != nil {
// 			t.Fatalf("StartForward failed: %v", err)
// 		}

// 		cache.SetLayer(0)
// 		tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
// 		cache.Put(context, tensor, tensor)

// 		// with window size 4, nothing has slid out of the window yet
// 		if !cache.CanResume(0, 0) {
// 			t.Errorf("CanResume(0, 0) = false, want true (within window)")
// 		}
// 		if !cache.CanResume(0, 1) {
// 			t.Errorf("CanResume(0, 1) = false, want true (within window)")
// 		}
// 		if !cache.CanResume(0, 2) {
// 			t.Errorf("CanResume(0, 2) = false, want true (within window)")
// 		}
// 		if !cache.CanResume(0, 3) {
// 			t.Errorf("CanResume(0, 3) = false, want true (latest position)")
// 		}
// 		if !cache.CanResume(0, 4) {
// 			t.Errorf("CanResume(0, 4) = false, want true (latest position)")
// 		}

// 		// shift window by adding position 5
// 		err = cache.StartForward(context, input.Batch{
// 			Positions: []int32{5},
// 			Sequences: []int{0},
// 		}, false)
// 		if err != nil {
// 			t.Fatalf("StartForward failed: %v", err)
// 		}

// 		cache.SetLayer(0)
// 		tensor = context.FromFloats([]float32{6}, 1, 1, 1)
// 		cache.Put(context, tensor, tensor)

// 		// only the latest position has overlapping windows
// 		if cache.CanResume(0, 0) {
// 			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 1) {
// 			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 2) {
// 			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 3) {
// 			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 4) {
// 			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// 		}
// 		if !cache.CanResume(0, 5) {
// 			t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
// 		}
// 	})
// }

// func TestCanResumeSWAMem(t *testing.T) {
// 	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// 		windowSize := int32(4)
// 		memSize := int32(5)
// 		cache := NewSWAMemCache(windowSize, memSize, nil)
// 		defer cache.Close()

// 		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

// 		context := backend.NewContext()
// 		defer context.Close()

// 		err := cache.StartForward(context, input.Batch{
// 			Positions: []int32{0, 1, 2, 3, 4, 5, 6},
// 			Sequences: []int{0, 0, 0, 0, 0, 0, 0},
// 		}, false)
// 		if err != nil {
// 			t.Fatalf("StartForward failed: %v", err)
// 		}

// 		cache.SetLayer(0)
// 		tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
// 		cache.Put(context, tensor, tensor)

// 		// shift window by adding position 7
// 		err = cache.StartForward(context, input.Batch{
// 			Positions: []int32{7},
// 			Sequences: []int{0},
// 		}, false)
// 		if err != nil {
// 			t.Fatalf("StartForward failed: %v", err)
// 		}

// 		cache.SetLayer(0)
// 		tensor = context.FromFloats([]float32{8}, 1, 1, 1)
// 		cache.Put(context, tensor, tensor)

// 		// only the latest position has overlapping windows
// 		if cache.CanResume(0, 0) {
// 			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 1) {
// 			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 2) {
// 			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 3) {
// 			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 4) {
// 			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// 		}
// 		if cache.CanResume(0, 5) {
// 			t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
// 		}
// 		if !cache.CanResume(0, 6) {
// 			t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
// 		}
// 		if !cache.CanResume(0, 7) {
// 			t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
// 		}
// 	})
// }

// type testBackend struct {
// 	ml.Backend
// 	permutedV bool
// }

// func (b *testBackend) NewContext() ml.Context {
// 	return &testContext{}
// }

// func (b *testBackend) NewContextSize(int) ml.Context {
// 	return &testContext{}
// }

// func (b *testBackend) CacheConfig() ml.CacheConfig {
// 	return ml.CacheConfig{PermutedV: b.permutedV}
// }

// type testContext struct {
// 	ml.Context
// }

// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
// 	total := 0

// 	if len(shape) > 0 {
// 		total = 1
// 		for _, s := range shape {
// 			total *= s
// 		}
// 	}

// 	return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
// }

// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
// 	return c.Empty(dtype, shape...)
// }

// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
// 	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)

// 	copy(t.data, s)

// 	return t
// }

// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
// 	f := make([]float32, len(s))
// 	for i := range f {
// 		f[i] = float32(s[i])
// 	}

// 	out := c.FromFloats(f, shape...)
// 	out.(*testTensor).dtype = ml.DTypeI32

// 	return out
// }

// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
// 	s := make([]float32, 0, int((stop-start)/step))
// 	for i := start; i < stop; i += step {
// 		s = append(s, i)
// 	}

// 	out := c.FromFloats(s, len(s))
// 	out.(*testTensor).dtype = dtype
// 	return out
// }

// func (c *testContext) Input() ml.Context { return c }
// func (c *testContext) Layer(int) ml.Context { return c }

// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

// func (c *testContext) Compute(...ml.Tensor) {}

// func (c *testContext) Reserve() {}

// func (c *testContext) MaxGraphNodes() int {
// 	return 10
// }

// func (c *testContext) Close() {}

// type testTensor struct {
// 	ml.Tensor

// 	dtype       ml.DType
// 	elementSize int
// 	data        []float32
// 	shape       []int
// }

// func (t *testTensor) Dim(n int) int {
// 	return t.shape[n]
// }

// func (t *testTensor) Stride(n int) int {
// 	stride := t.elementSize
// 	for i := range n {
// 		stride *= t.shape[i]
// 	}

// 	return stride
// }

// func (t *testTensor) Shape() []int {
// 	return t.shape
// }

// func (t *testTensor) DType() ml.DType {
// 	return t.dtype
// }

// func (t *testTensor) Floats() []float32 {
// 	out := make([]float32, len(t.data))
// 	copy(out, t.data)
// 	return out
// }

// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
// 	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// 	for i := range out.data {
// 		out.data[i] = -t.data[i]
// 	}
// 	return out
// }

// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// 	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

// 	for i := range out.data {
// 		out.data[i] = t.data[i] + t2.(*testTensor).data[i]
// 	}

// 	return out
// }

// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
// 	return &testTensor{
// 		dtype:       t.dtype,
// 		elementSize: t.elementSize,
// 		data:        t.data,
// 		shape:       shape,
// 	}
// }

// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
// 	offset /= t.elementSize

// 	var s []int

// 	switch len(shape) {
// 	case 1:
// 		s = []int{shape[0]}
// 	case 3:
// 		s = []int{shape[0], shape[2]}
// 	case 5:
// 		s = []int{shape[0], shape[2], shape[4]}
// 	default:
// 		panic("unsupported number of dimensions")
// 	}

// 	context := &testContext{}

// 	view := context.Empty(t.dtype, s...).(*testTensor)
// 	view.data = t.data[offset : offset+len(view.data)]

// 	return view
// }

// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
// 	if len(t.shape) > 4 || len(order) > 4 {
// 		panic("permute only supports up to 4 dimensions")
// 	}

// 	if len(order) != len(t.shape) && len(order) != 4 {
// 		panic("invalid number of dimensions for permute")
// 	}

// 	// ggml_permute expects 4 axes, so fill in any missing dimensions.
// 	orderFull := append(make([]int, 0, 4), order...)
// 	for len(orderFull) < 4 {
// 		orderFull = append(orderFull, len(orderFull))
// 	}

// 	seen := [4]bool{}

// 	shape4 := [4]int{1, 1, 1, 1}
// 	for i := 0; i < len(t.shape) && i < 4; i++ {
// 		shape4[i] = t.shape[i]
// 	}

// 	newShape4 := [4]int{1, 1, 1, 1}
// 	for axis := range 4 {
// 		dst := orderFull[axis]
// 		if dst < 0 || dst >= 4 {
// 			panic("invalid axis for permute")
// 		}
// 		if seen[dst] {
// 			panic("duplicate axis for permute")
// 		}
// 		seen[dst] = true
// 		newShape4[dst] = shape4[axis]
// 	}

// 	total := len(t.data)
// 	newData := make([]float32, total)

// 	if total > 0 {
// 		oldDims := shape4
// 		newDims := newShape4

// 		oldStride := [4]int{1, 1, 1, 1}
// 		newStride := [4]int{1, 1, 1, 1}
// 		for i := 1; i < 4; i++ {
// 			oldStride[i] = oldStride[i-1] * oldDims[i-1]
// 			newStride[i] = newStride[i-1] * newDims[i-1]
// 		}

// 		var coords [4]int
// 		var newCoords [4]int

// 		for idx := range total {
// 			remainder := idx
// 			for axis := range 4 {
// 				dim := oldDims[axis]
// 				if dim == 0 {
// 					coords[axis] = 0
// 					continue
// 				}
// 				coords[axis] = remainder % dim
// 				remainder /= dim
// 			}

// 			for axis := range 4 {
// 				newCoords[orderFull[axis]] = coords[axis]
// 			}

// 			newIndex := 0
// 			for axis := range 4 {
// 				if newDims[axis] == 0 {
// 					continue
// 				}
// 				newIndex += newCoords[axis] * newStride[axis]
// 			}

// 			newData[newIndex] = t.data[idx]
// 		}
// 	}

// 	numDims := 4
// 	for numDims > 1 && newShape4[numDims-1] <= 1 {
// 		numDims--
// 	}

// 	newShape := make([]int, numDims)
// 	copy(newShape, newShape4[:numDims])

// 	return &testTensor{
// 		dtype:       t.dtype,
// 		elementSize: t.elementSize,
// 		data:        newData,
// 		shape:       newShape,
// 	}
// }

// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
// 	dst := t
// 	srcTensor := src.(*testTensor)
// 	idxTensor := idxs.(*testTensor)

// 	shapeTo4D := func(shape []int) [4]int {
// 		out := [4]int{1, 1, 1, 1}
// 		for i := 0; i < len(shape) && i < 4; i++ {
// 			out[i] = shape[i]
// 		}
// 		return out
// 	}

// 	computeStrides := func(shape [4]int) [4]int {
// 		out := [4]int{1, 1, 1, 1}
// 		for i := 1; i < 4; i++ {
// 			out[i] = out[i-1] * shape[i-1]
// 		}
// 		return out
// 	}

// 	dstShape4D := shapeTo4D(dst.shape)
// 	srcShape4D := shapeTo4D(srcTensor.shape)
// 	idxShape4D := shapeTo4D(idxTensor.shape)

// 	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
// 		panic("SetRows requires matching tensor shapes")
// 	}

// 	if srcShape4D[1] != idxShape4D[0] {
// 		panic("SetRows rows/index mismatch")
// 	}

// 	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
// 		panic("SetRows cannot broadcast indices")
// 	}

// 	if idxShape4D[3] != 1 {
// 		panic("SetRows expects 1D or 2D index tensors")
// 	}

// 	dstStride := computeStrides(dstShape4D)
// 	srcStride := computeStrides(srcShape4D)
// 	idxStride := computeStrides(idxShape4D)

// 	numColumns := srcShape4D[0]
// 	numRows := srcShape4D[1]

// 	for dim3Index := range dstShape4D[3] {
// 		for dim2Index := range dstShape4D[2] {
// 			idxDim2 := 0
// 			idxDim3 := 0
// 			if idxShape4D[1] > 0 {
// 				idxDim2 = dim2Index % idxShape4D[1]
// 			}
// 			if idxShape4D[2] > 0 {
// 				idxDim3 = dim3Index % idxShape4D[2]
// 			}

// 			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
// 			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
// 			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]

// 			for row := range numRows {
// 				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
// 				if idx < 0 || idx >= dstShape4D[1] {
// 					panic("SetRows index out of range")
// 				}

// 				srcOffset := srcBase + row*srcStride[1]
// 				dstOffset := dstBase + idx*dstStride[1]

// 				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
// 			}
// 		}
// 	}

// 	return dst
// }

// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// 	copy(t2.(*testTensor).data, t.data)
// 	return nil
// }
@@ -1,156 +0,0 @@
package kvcache

// import (
// 	"fmt"

// 	"github.com/ollama/ollama/ml"
// 	"github.com/ollama/ollama/model/input"
// )

// // Encoder cache stores K and V tensors that are position independent
// //
// // The tensors can be of any shape and will be returned as they were stored
// // The mask is currently always nil
// //
// // Not currently safe for multiple sequences
// type EncoderCache struct {
// 	// config controls mostly backend-specific optimizations
// 	config *ml.CacheConfig

// 	// ** current forward pass **

// 	// the active layer for Get and Put
// 	curLayer int

// 	// if something is stored during this pass, this
// 	// will be the position (but there is no guarantee
// 	// anything will be stored)
// 	curPos int32

// 	// curReserve indicates that this forward pass is only for
// 	// memory reservation and we should not update our metadata
// 	// based on it.
// 	curReserve bool

// 	// ** cache metadata **

// 	// was something stored in the cache?
// 	encoderCached bool

// 	// position of the cached data
// 	encoderPos int32

// 	// ** cache data storage **
// 	backend      ml.Backend
// 	ctxs         map[int]ml.Context
// 	keys, values map[int]ml.Tensor
// }

// func NewEncoderCache() *EncoderCache {
// 	return &EncoderCache{
// 		ctxs:   make(map[int]ml.Context),
// 		keys:   make(map[int]ml.Tensor),
// 		values: make(map[int]ml.Tensor),
// 	}
// }

// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// 	if c.config == nil {
// 		var config ml.CacheConfig
// 		if cc, ok := backend.(ml.BackendCacheConfig); ok {
// 			config = cc.CacheConfig()
// 		}
// 		c.config = &config
// 	}

// 	if maxSequences > 1 {
// 		panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
// 	}

// 	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
// 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
// 	}

// 	c.backend = backend
// }

// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
// 	if c.config != nil {
// 		panic("config cannot be changed after being previously set, either by the model or backend")
// 	}

// 	c.config = &config
// }

// func (c *EncoderCache) Close() {
// 	for _, ctx := range c.ctxs {
// 		ctx.Close()
// 	}
// }

// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// 	// We work with the most recent image
// 	if len(batch.Multimodal) > 0 {
// 		c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
// 	}

// 	c.curReserve = reserve

// 	return nil
// }

// func (c *EncoderCache) SetLayer(layer int) {
// 	c.curLayer = layer
// }

// func (c *EncoderCache) EncoderCached() bool {
// 	return c.encoderCached
// }

// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
// 	return c.keys[c.curLayer], c.values[c.curLayer], nil
// }

// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
// 	if !c.curReserve {
// 		c.encoderPos = c.curPos
// 		c.encoderCached = true
// 	}

// 	if c.config.PermutedV {
// 		value = value.Transpose(ctx, 1, 2, 0, 3)
// 	}

// 	if _, ok := c.ctxs[c.curLayer]; !ok {
// 		c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
// 	}

// 	if _, ok := c.keys[c.curLayer]; !ok {
// 		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
// 	}

// 	if _, ok := c.values[c.curLayer]; !ok {
// 		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
// 	}

// 	ctx.Forward(
// 		key.Copy(ctx, c.keys[c.curLayer]),
// 		value.Copy(ctx, c.values[c.curLayer]),
// 	)
// }

// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
// 	panic("encoder cache does not support multiple sequences")
// }

// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
// 	return true
// }

// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
// 	if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
// 		c.encoderCached = false
// 	}

// 	return nil
// }
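
// Worked example (illustrative): if an image was cached at encoderPos = 7,
// Remove(0, 5, math.MaxInt32) invalidates it because 7 >= 5, while
// Remove(0, 0, 5) leaves it cached because 7 falls outside [0, 5).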
144
x/kvcache/mlx.go
@@ -1,144 +0,0 @@
//go:build mlx

package kvcache

import (
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model/input"
)

// MLXCausal stores K and V tensors according to their position in the
// sequence. It returns the history and a mask for attending to past tokens.
type MLXCausal struct {
|
||||
DType ml.DType
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocPut ml.Tensor
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLocGet ml.Tensor
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
capacity int
|
||||
|
||||
offset int
|
||||
|
||||
backend ml.Backend
|
||||
ctxs map[int]ml.Context
|
||||
keys, values map[int]ml.Tensor
|
||||
|
||||
// TODO is this needed per layer, or will it always be consistent?
|
||||
kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
}
|
||||
|
||||
func NewMLXCausalCache() *MLXCausal {
|
||||
return &MLXCausal{
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
kHeadDims: make(map[int]int),
|
||||
vHeadDims: make(map[int]int),
|
||||
numKVHeads: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.DType = dtype
|
||||
c.capacity = capacity
|
||||
c.backend = backend
|
||||
}
|
||||
|
||||
func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
|
||||
|
||||
func (c *MLXCausal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Close() {
|
||||
// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
|
||||
for _, ctx := range c.ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
locsPut := make([]int32, len(batch.Positions))
|
||||
for i := c.offset; i < len(batch.Positions); i++ {
|
||||
locsPut[i-c.offset] = int32(i)
|
||||
}
|
||||
c.offset += len(batch.Positions)
|
||||
locsGet := make([]int32, c.offset)
|
||||
for i := range c.offset {
|
||||
locsGet[i] = int32(i)
|
||||
}
|
||||
c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
|
||||
c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
|
||||
// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
kHeadDim := key.Dim(3)
|
||||
vHeadDim := value.Dim(3)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
kCellSize := kHeadDim * numKVHeads
|
||||
vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
|
||||
|
||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
|
||||
c.kHeadDims[c.curLayer] = kHeadDim
|
||||
c.vHeadDims[c.curLayer] = vHeadDim
|
||||
c.numKVHeads[c.curLayer] = numKVHeads
|
||||
}
|
||||
key = key.Reshape(ctx, batchSize, 1, kCellSize)
|
||||
|
||||
// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
|
||||
// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
|
||||
// slog.Info("XXX MLXCausal.Put ", "key", key)
|
||||
ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
|
||||
value = value.Reshape(ctx, batchSize, 1, vCellSize)
|
||||
ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
|
||||
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := c.kHeadDims[c.curLayer]
|
||||
vHeadDim := c.vHeadDims[c.curLayer]
|
||||
numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// rowSize := numKVHeads * c.curBatchSize
|
||||
// cachedSize := c.curMask.Dim(1)
|
||||
cachedSize := c.curLocGet.Dim(0)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
|
||||
|
||||
key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
return key, value, nil
|
||||
}
|
||||
|
||||
func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *MLXCausal) CanResume(seq int, pos int32) bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
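For orientation, the cache above keeps one flat [capacity, numKVHeads*headDim] buffer per layer: Put scatters one row per token at the locations in curLocPut, and Get gathers the rows in curLocGet back into attention shape. A standalone sketch of that shape bookkeeping (dimensions are made up for illustration, not taken from this branch):

package main

import "fmt"

func main() {
    numKVHeads, headDim, batchSize, capacity := 2, 4, 3, 8
    cellSize := numKVHeads * headDim // one cache row per token

    fmt.Println("incoming key:", []int{1, numKVHeads, batchSize, headDim}) // [1 2 3 4]
    fmt.Println("scatter rows:", []int{batchSize, 1, cellSize})            // [3 1 8]
    fmt.Println("cache buffer:", []int{capacity, cellSize})                // [8 8]
    fmt.Println("gathered key:", []int{1, numKVHeads, batchSize, headDim}) // [1 2 3 4]
}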
@@ -1,110 +0,0 @@
package kvcache

// import (
//    "math"
//
//    "github.com/ollama/ollama/ml"
//    "github.com/ollama/ollama/model/input"
// )

// // WrapperCache is a container for multiple types of caches,
// // such as for the encoding and decoding portions of a model.
// type WrapperCache struct {
//    // caches we are wrapping
//    caches []Cache
//
//    // cache to be used for this layer
//    curType int
// }

// func NewWrapperCache(caches ...Cache) *WrapperCache {
//    return &WrapperCache{
//        caches: caches,
//    }
// }

// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
//    for _, cache := range c.caches {
//        cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
//    }
// }

// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
//    for _, cache := range c.caches {
//        cache.SetConfig(config)
//    }
// }

// func (c *WrapperCache) Close() {
//    for _, cache := range c.caches {
//        cache.Close()
//    }
// }

// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
//    for i, cache := range c.caches {
//        err := cache.StartForward(ctx, batch, reserve)
//        if err != nil {
//            // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
//            for j := i - 1; j >= 0; j-- {
//                for k := range batch.Positions {
//                    _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
//                }
//            }
//            return err
//        }
//    }
//
//    c.curType = 0
//    return nil
// }

// func (c *WrapperCache) SetLayer(layer int) {
//    for _, cache := range c.caches {
//        cache.SetLayer(layer)
//    }
// }

// func (c *WrapperCache) SetLayerType(layerType int) {
//    c.curType = layerType
// }

// func (c *WrapperCache) UnderlyingCache() Cache {
//    return c.caches[c.curType]
// }

// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
//    return c.caches[c.curType].Get(ctx)
// }

// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
//    c.caches[c.curType].Put(ctx, key, value)
// }

// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
//    for _, cache := range c.caches {
//        cache.CopyPrefix(srcSeq, dstSeq, len)
//    }
// }

// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
//    for _, cache := range c.caches {
//        if !cache.CanResume(seq, pos) {
//            return false
//        }
//    }
//
//    return true
// }

// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
//    // If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
//    for _, cache := range c.caches {
//        err := cache.Remove(seq, beginIndex, endIndex)
//        if err != nil {
//            return err
//        }
//    }
//
//    return nil
// }
433
x/ml/backend.go
@@ -1,433 +0,0 @@
package ml

import (
    "fmt"
    "log/slog"
    "os"

    "github.com/ollama/ollama/fs"
)

type Backend interface {
    // Close frees all memory associated with this backend
    // Close()

    // Load(ctx context.Context, progress func(float32)) error

    // BackendMemory returns the memory allocations that were made for this model
    // BackendMemory() BackendMemory

    Config() fs.Config
    Get(name string) Tensor
    NewContext() Context
    // NewContextSize(size int) Context

    // Enumerate the devices available for inference via this backend
    // BackendDevices() []DeviceInfo
}

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
    CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
    // CachePadding specifies the multiple for the number of tokens of cache history
    // that will be returned from cache Get for k, v and mask. The capacity of the
    // cache itself will also be increased to a multiple of this size if needed.
    CachePadding int

    // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
    // and returns the permuted version via Get. This uses the cache copy operation
    // to avoid a Contiguous call on the permuted tensor.
    PermutedV bool

    // MaskDType specifies the data type for generating the mask. If unset it will
    // default to DTypeFloat32.
    MaskDType DType

    // MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
    // Any position that does not correspond to an actual token will be filled with -Inf.
    MaskBatchPadding int
}
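For reference, the Gemma 3 test later in this diff configures the cache with padding disabled (copied verbatim from that test):

    cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})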
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
    // AllocMemory causes the backend to allocate memory for the model. If
    // false, this is only being used for discovering the required amount of
    // memory and cannot load the model for running.
    AllocMemory bool

    // NumThreads sets the number of threads to use if running on the CPU
    NumThreads int

    // GPULayers is the set of layers to offload to GPUs
    GPULayers GPULayersList

    // FlashAttention indicates that we should use a fused flash attention kernel
    FlashAttention bool
}

var backends = make(map[string]func(string, BackendParams) (Backend, error))

func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
    if _, ok := backends[name]; ok {
        panic("backend: backend already registered")
    }

    backends[name] = f
}

func NewBackend(modelPath string, params BackendParams) (Backend, error) {
    be := os.Getenv("OLLAMA_BACKEND")
    if be == "" {
        be = "mlx"
        slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
    }
    slog.Info("Loading new engine", "backend", be)
    if backend, ok := backends[be]; ok {
        return backend(modelPath, params)
    }

    return nil, fmt.Errorf("unsupported backend")
}
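A backend implementation is expected to register itself at package init time; a minimal sketch of the pattern (the constructor name here is hypothetical):

func init() {
    RegisterBackend("mlx", func(modelPath string, params BackendParams) (Backend, error) {
        return newMLXBackend(modelPath, params) // hypothetical constructor
    })
}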
type Context interface {
    Empty(dtype DType, shape ...int) Tensor
    Zeros(dtype DType, shape ...int) Tensor
    // FromBytes(dtype DType, s []byte, shape ...int) Tensor
    FromFloats(s []float32, shape ...int) Tensor
    FromInts(s []int32, shape ...int) Tensor
    RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor

    // Arange creates a 1D tensor with values in the half-open interval [start, stop), increasing by step.
    Arange(start, stop, step float32, dtype DType) Tensor

    Forward(...Tensor) Context

    // SetBatchSize provides a hint on the batch size to optimize processing.
    // Uses heuristics if not set.
    // SetBatchSize(int)

    Compute(...Tensor)
    // ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun

    // Reserve is analogous to Compute but rather than executing a
    // graph, simply preallocates memory. Typically called with a
    // worst case graph to ensure all resources are available for
    // future inference.
    // Reserve()

    // MaxGraphNodes() int
    Close()

    // Input returns a context appropriate for creating tensors that are
    // inputs to the model (which includes things like output locations)
    Input() Context

    // Layer returns a context appropriate for creating intermediate tensors
    Layer(int) Context

    // CompareWith loads tensors from the "filename" safetensors file and compares
    // them with the input tensors. It returns an error if a shape is inconsistent
    // or similarity measures are below 99%.
    CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
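The tests later in this diff rely on Arange's half-open interval directly, for example:

    t1 := c.Arange(0, 24, 1, ml.DTypeFloat16) // 24 values: 0, 1, ..., 23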
type RoPEOptions struct {
    Base  *float32
    Freqs Tensor
}

func WithRoPEBase(base float32) func(*RoPEOptions) {
    return func(opts *RoPEOptions) {
        opts.Base = &base
    }
}

func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
    return func(opts *RoPEOptions) {
        opts.Freqs = freqs
    }
}

type Tensor interface {
    ToString() string
    RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
    ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
    TakeAxes(ctx Context, indices Tensor, axes int) Tensor
    // TakeAxes(ctx Context, axes int, indices ...int) Tensor

    Dim(n int) int
    Stride(n int) int

    Shape() []int
    DType() DType
    // Cast(ctx Context, dtype DType) Tensor

    // Bytes() []byte
    Floats() []float32
    Ints() []int32

    // FromBytes([]byte)
    // FromFloats([]float32)
    // FromInts([]int32)

    Add(ctx Context, t2 Tensor) Tensor
    Sub(ctx Context, t2 Tensor) Tensor
    // Mul(ctx Context, t2 Tensor) Tensor
    // Div(ctx Context, t2 Tensor) Tensor

    Max(ctx Context, axes []int, keepDims bool) Tensor
    Min(ctx Context, axes []int, keepDims bool) Tensor

    Matmul(ctx Context, a2 Tensor) Tensor
    // Mulmat(ctx Context, t2 Tensor) Tensor
    // MulmatFullPrec(ctx Context, t2 Tensor) Tensor
    // MulmatID(ctx Context, t2, ids Tensor) Tensor
    // AddID(ctx Context, t2, ids Tensor) Tensor

    Softmax(ctx Context) Tensor
    L2Norm(ctx Context, eps float32) Tensor
    LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
    RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
    Scale(ctx Context, s float64) Tensor
    // SumRows(ctx Context) Tensor

    AvgPool2D(ctx Context, k, s int, p float32) Tensor
    Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
    Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor

    // IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

    // Sin(ctx Context) Tensor
    // Cos(ctx Context) Tensor
    // Tanh(ctx Context) Tensor
    GELU(ctx Context, up ...Tensor) Tensor
    // QuickGELU(ctx Context, up ...Tensor) Tensor
    // SILU(ctx Context, up ...Tensor) Tensor
    // RELU(ctx Context, up ...Tensor) Tensor
    // Sigmoid(ctx Context) Tensor

    // SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
    // SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

    Reshape(ctx Context, shape ...int) Tensor
    AsStrided(ctx Context, shape, strides []int, offset int) Tensor
    Transpose(ctx Context, shape ...int) Tensor
    Contiguous(ctx Context, allowColMajor bool) Tensor

    // Pad(ctx Context, shape ...int) Tensor

    // Stack(ctx Context, dim int, s ...Tensor) Tensor

    // Repeat repeats the tensor n times along dimension dim
    // Repeat(ctx Context, dim, n int) Tensor
    // Concat(ctx Context, t2 Tensor, dim int) Tensor
    // Rows(ctx Context, t2 Tensor) Tensor

    // TODO these probably aren't actually needed - false starts on trying to wire up cache
    // SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
    // SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
    // PutAlongAxis(ctx Context, indices, values Tensor, axis int) Tensor

    Scatter(ctx Context, indices []Tensor, updates Tensor, axes []int) Tensor

    Copy(ctx Context, t2 Tensor) Tensor
    // Duplicate(ctx Context) Tensor

    // Slice(ctx Context, dim, low, high, step int) Tensor
    // Chunk(ctx Context, dim int, size int) []Tensor
    // ChunkSections(ctx Context, dim int, sections ...int) []Tensor

    // TopK(ctx Context, k int) Tensor
    // Argsort(ctx Context) Tensor
    // Mean(ctx Context) Tensor
    // Variance(ctx Context) Tensor
    // Stddev(ctx Context) Tensor
    // Sqr(ctx Context) Tensor
    // Sqrt(ctx Context) Tensor

    // Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}

// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//    query = query.Permute(ctx, 0, 2, 1, 3)
//    key = key.Permute(ctx, 0, 2, 1, 3)
//    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//    kq := key.MulmatFullPrec(ctx, query)
//
//    kq = kq.Scale(ctx, scale)
//
//    if mask != nil {
//        kq = kq.Add(ctx, mask)
//    }
//
//    kq = kq.Softmax(ctx)
//
//    kqv := value.Mulmat(ctx, kq)
//    return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
//    ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }

// type number interface {
//    ~int | ~int8 | ~int16 | ~int32 | ~int64 |
//        ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
//        ~float32 | ~float64 |
//        ~complex64 | ~complex128
// }

// func mul[T number](s ...T) T {
//    p := T(1)
//    for _, v := range s {
//        p *= v
//    }
//
//    return p
// }

// type DumpOptions func(*dumpOptions)

// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
//    return func(opts *dumpOptions) {
//        opts.Precision = n
//    }
// }

// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
//    return func(opts *dumpOptions) {
//        opts.Threshold = n
//    }
// }

// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
//    return func(opts *dumpOptions) {
//        opts.EdgeItems = n
//    }
// }

// type dumpOptions struct {
//    Precision, Threshold, EdgeItems int
// }

// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
//    opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
//    for _, optsFunc := range optsFuncs {
//        optsFunc(&opts)
//    }
//
//    if mul(t.Shape()...) <= opts.Threshold {
//        opts.EdgeItems = math.MaxInt
//    }
//
//    switch t.DType() {
//    case DTypeFloat32:
//        return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
//            return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//        })
//    case DTypeFloat16: // TODO other types...
//        f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
//        f32 = t.Copy(ctx, f32)
//        return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
//            return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//        })
//    case DTypeInt32:
//        return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
//            return strconv.FormatInt(int64(i), 10)
//        })
//    default:
//        return "<unsupported>"
//    }
// }

// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
//    if t.Bytes() == nil {
//        ctx.Compute(t)
//    }
//
//    s := make(S, mul(t.Shape()...))
//    if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
//        panic(err)
//    }
//
//    shape := t.Shape()
//    slices.Reverse(shape)
//
//    var sb strings.Builder
//    var f func([]int, int)
//    f = func(dims []int, stride int) {
//        prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
//        sb.WriteString("[")
//        defer func() { sb.WriteString("]") }()
//        for i := 0; i < dims[0]; i++ {
//            if i >= items && i < dims[0]-items {
//                sb.WriteString("..., ")
//                // skip to next printable element
//                skip := dims[0] - 2*items
//                if len(dims) > 1 {
//                    stride += mul(append(dims[1:], skip)...)
//                    fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
//                }
//                i += skip - 1
//            } else if len(dims) > 1 {
//                f(dims[1:], stride)
//                stride += mul(dims[1:]...)
//                if i < dims[0]-1 {
//                    fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
//                }
//            } else {
//                text := fn(s[stride+i])
//                if len(text) > 0 && text[0] != '-' {
//                    sb.WriteString(" ")
//                }
//
//                sb.WriteString(text)
//                if i < dims[0]-1 {
//                    sb.WriteString(", ")
//                }
//            }
//        }
//    }
//    f(shape, 0)
//
//    return sb.String()
// }

type DType int

const (
    DTypeBool DType = iota
    DTypeUint8
    DTypeUint16
    DTypeUint32
    DTypeUint64
    DTypeInt8
    DTypeInt16
    DTypeInt32
    DTypeInt64
    DTypeFloat16
    DTypeFloat32
    DTypeFloat64
    DTypeBfloat16
    DTypeComplex64
)

type SamplingMode int

const (
    SamplingModeNearest SamplingMode = iota
    SamplingModeBilinear
)
@@ -1,3 +0,0 @@
package backend

// _ "github.com/ollama/ollama/x/ml/backend/mlx"
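Presumably this commented-out blank import is what would link the MLX backend into the binary, triggering its init-time RegisterBackend call; re-enabled it would read:

    import _ "github.com/ollama/ollama/x/ml/backend/mlx" // registers the backend via init()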
@@ -1,57 +0,0 @@
include(FetchContent)

set(MLX_C_BUILD_EXAMPLES OFF)

set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS ON)

function(set_target_output_directory _target)
    if(TARGET ${_target})
        set_target_properties(${_target} PROPERTIES
            RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
            LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
            ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}
        )
    endif()
endfunction()

# Check for Metal support (macOS only)
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
    execute_process(
        COMMAND
            zsh "-c"
            "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
        OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

    if(NOT MLX_METAL_VERSION)
        message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
        set(MLX_BUILD_METAL OFF)
    endif()
else()
    # On Linux, disable Metal backend
    message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF")
    set(MLX_BUILD_METAL OFF)
endif()

# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set
if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES)
    set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
    message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}")
endif()

# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available
if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER)
    set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE)
    message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}")
elseif(MLX_CUDA_ARCHITECTURES)
    message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled")
endif()

FetchContent_Declare(
    mlx-c
    GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
    GIT_TAG v0.4.1)
FetchContent_MakeAvailable(mlx-c)

set_target_output_directory(mlx)
set_target_output_directory(mlxc)
@@ -1,314 +0,0 @@
//go:build mlx

package mlx

import (
    "log/slog"
    "os"
    "reflect"
    "strings"
    "testing"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/runner/common"
    "github.com/ollama/ollama/sample"
    "github.com/ollama/ollama/x/ml"
    "github.com/ollama/ollama/x/model"
    "github.com/ollama/ollama/x/model/input"
    _ "github.com/ollama/ollama/x/model/models/gemma3"
)

func init() {
    logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
    slog.SetDefault(logger)
}

func TestLoadModel(t *testing.T) {
    dir := "/Users/daniel/Models/gemma-3-4b-it/"
    b := &Backend{}
    err := b.LoadSafeTensors(dir)
    if err != nil {
        t.Fatalf("load failed: %s", err)
    }
}

func TestFromInts(t *testing.T) {
    b := &Backend{}
    c := b.NewContext()
    defer c.Close()
    data := []int32{1, 2, 3, 4, 5, 6}
    a := c.FromInts(data, 2, 3)
    slog.Info("", "array", a)
    t.Log(a.ToString())
    if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
        t.Fatalf("incorrect shape: %v", a.Shape())
    }
}

func TestFromFloats(t *testing.T) {
    b := &Backend{}
    c := b.NewContext()
    defer c.Close()
    data := []float32{1, 2, 3, 4, 5, 6}
    a := c.FromFloats(data, 2, 3)
    slog.Info("", "array", a)
    t.Log(a.ToString())
    if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
        t.Fatalf("incorrect shape: %v", a.Shape())
    }
    res := a.Floats()
    if !reflect.DeepEqual(res, data) {
        t.Fatalf("incorrect results: %v", res)
    }
}

func TestAdd(t *testing.T) {
    b := &Backend{}
    c := b.NewContext()
    defer c.Close()
    t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
    t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
    exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
    t3 := t1.Add(c, t2)
    c.Compute(t3, exp)
    t3f := t3.Floats()
    if !reflect.DeepEqual(t3f, exp.Floats()) {
        t.Fatalf("incorrect result: %v", t3f)
    }
}

func TestReshapeTranspose(t *testing.T) {
    b := &Backend{}
    c := b.NewContext()
    defer c.Close()
    t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
    c.Compute(t1)
    t1f := t1.Floats()
    exp := []float32{
        0, 4, 8,
        1, 5, 9,
        2, 6, 10,
        3, 7, 11,
        12, 16, 20,
        13, 17, 21,
        14, 18, 22,
        15, 19, 23,
    }
    if !reflect.DeepEqual(t1f, exp) {
        t.Fatalf("incorrect results: %v", t1f)
    }
}

func prod(vals ...int) int {
    r := 1
    for _, v := range vals {
        r *= v
    }
    return r
}

func TestMatmul(t *testing.T) {
    // TODO create scenarios...
    b := &Backend{}
    c := b.NewContext()
    defer c.Close()
    s1 := []int{1, 3, 2, 4}
    t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
    s2 := []int{4, 2}
    t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
    t3 := t1.Matmul(c, t2)
    exp := []float32{
        28, 34,
        76, 98,

        124, 162,
        172, 226,

        220, 290,
        268, 354,
    }
    c.Compute(t3)
    t3f := t3.Floats()
    if !reflect.DeepEqual(t3f, exp) {
        t.Fatalf("incorrect result: %v", t3f)
    }
}

func TestRows(t *testing.T) {
    b := &Backend{}
    c := b.NewContext()
    defer c.Close()
    t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
    outputs := c.Zeros(ml.DTypeInt32, 1)
    t2 := t1.TakeAxes(c, outputs, 1)
    c.Forward(t1, t2).Compute(t1, t2)
    t.Log(t1.ToString())
    t.Log(t2.ToString())
    f := t2.Floats()
    t.Logf("Result: %v", f)
}

func TestCaching(t *testing.T) {
    // Validate the caching algorithm
    b := &Backend{}
    c := b.NewContext()
    defer c.Close()
    batchSize := 3
    headDim := 4
    numKVHeads := 2
    // Make cache twice the size of one test batch
    cells := batchSize * 2
    cellSize := numKVHeads * headDim
    shape := []int{1, numKVHeads, batchSize, headDim}
    stop := float32(1)
    for _, x := range shape {
        stop *= float32(x)
    }
    // Create the cache
    cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
    t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})

    // Input tensor
    t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
    t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)

    // Reshape to copy into the cache
    /*
        From MLX python/src/indexing.cpp mlx_scatter_args_array
        // The update shape must broadcast with indices.shape + [1] + src.shape[1:]
        auto up_shape = indices.shape();
        up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
        up = broadcast_to(up, up_shape);
        up_shape.insert(up_shape.begin() + indices.ndim(), 1);
        up = reshape(up, up_shape);
    */
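    // Worked through for this test (editorial note): the indices below have
    // shape [numRows] = [3], and each cache row holds cellSize = numKVHeads*headDim = 8
    // values, so the 24-element input is reshaped to [3, 1, 8] -- one [1, 8] row per index.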
    numRows := 3
    up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
    t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})

    // Simulate cells 1, 3 and 5 being available
    indices := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
    t.Logf("Indices shape%v\n"+indices[0].ToString(), []int{numRows})
    axis := []int{0} // the indices 1, 3, 5 refer to axis 0 of the cache shape
    cache.Scatter(c, indices, up, axis)

    c.Forward(cache)
    // Cache should contain the data now
    t.Log("Cache after put\n" + cache.ToString())

    // Retrieve cache content and verify it matches
    out := cache.TakeAxes(c, indices[0], 0).Reshape(c, shape...)
    t.Logf("Output shape%v\n"+out.ToString(), out.Shape())

    t1f := t1.Floats()
    outf := out.Floats()
    if !reflect.DeepEqual(t1f, outf) {
        t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
    }
}

func TestGemma3(t *testing.T) {
    // Why is the sky blue
    inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
    limit := 50

    // TODO generalize this
    dir := "/Users/daniel/Models/gemma-3-4b-it/"

    m, err := model.New(dir, ml.BackendParams{})
    if err != nil {
        t.Fatalf("unable to load model: %s", err)
    }
    b := m.Backend()
    ctx := b.NewContext()
    defer ctx.Close()

    batch := input.Batch{
        Inputs:    ctx.FromInts(inputs[:], 1, len(inputs)),
        Positions: make([]int32, len(inputs)),
        Sequences: make([]int, len(inputs)),
        Outputs:   ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
        Offset:    0,
    }
    for i := range len(inputs) {
        batch.Positions[i] = int32(i)
    }
    offset := len(inputs)

    cache := m.Config().Cache
    if cache != nil {
        numSlots := 1
        batchSize := 512
        numCtx := 4096

        // Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
        // cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
        cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})

        cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
        err := cache.StartForward(ctx, batch, false)
        if err != nil {
            t.Fatalf("failed cache.StartForward: %s", err)
        }
    }
    opts := api.DefaultOptions()
    var grammar *sample.GrammarSampler
    sampler := sample.NewSampler(
        opts.Temperature,
        opts.TopK,
        opts.TopP,
        opts.MinP,
        opts.Seed,
        grammar,
    )

    t.Log("Starting Forward pass loop")
    pendingResponses := []string{}
    for {
        out, err := m.Forward(ctx, batch)
        if err != nil {
            t.Fatalf("failed forward pass: %s", err)
        }
        ctx.Forward(out)
        outputs := out.Floats()
        t.Logf("finished forward pass! length:%d", len(outputs))
        // sample a token
        logits := outputs
        token, err := sampler.Sample(logits)
        if err != nil {
            t.Fatalf("unable to sample token: %s", err)
        }
        t.Logf("Sampled token: %v", token)
        if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
            t.Log("hit EOS")
            break
        }
        piece, err := m.(model.TextProcessor).Decode([]int32{token})
        if err != nil {
            t.Fatalf("unable to decode token: %s", err)
        }

        pendingResponses = append(pendingResponses, piece)
        sequence := strings.Join(pendingResponses, "")
        if ok, stop := common.FindStop(sequence, opts.Stop); ok {
            t.Logf("hit stop token: %v", stop)
            break
        }
        t.Logf("RESULTS: %s", sequence)
        batch = input.Batch{
            Inputs:    ctx.FromInts([]int32{token}, 1, 1),
            Positions: make([]int32, 1),
            Sequences: make([]int, 1),
            Outputs:   ctx.FromInts([]int32{0}, 1),
            Offset:    offset,
        }
        offset++
        batch.Positions[0] = 0
        err = cache.StartForward(ctx, batch, false)
        if err != nil {
            t.Fatalf("failed cache.StartForward: %s", err)
        }
        if offset > limit {
            break
        }
    }
}
@@ -1,335 +0,0 @@
//go:build mlx

package mlx

/*
#include <stdio.h>
#include <string.h>

#include "mlx/c/array.h"
#include "mlx/c/ops.h"

// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp

void unpack_32_4(uint8_t* data, int8_t* dst) {
    memset(dst, 0, 16);
    for (int j = 0; j < 16; ++j) {
        uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
        if (j % 2 != 0) {
            x <<= 4;
        }
        dst[j / 2] += x;
    }
    // Last 16 weights are in the higher bits
    for (int j = 0; j < 16; ++j) {
        uint8_t x = (data[j + 2] >> 4);
        if (j % 2 != 0) {
            x <<= 4;
        }
        dst[8 + j / 2] += x;
    }
}

// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
    const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
    uint8_t* weights = mlx_array_data_uint8(*weights_arr);
    float16_t* scales = mlx_array_data_float16(*scales_arr);
    float16_t* biases = mlx_array_data_float16(*biases_arr);
    for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
        scales[i] = *((float16_t*)data);
        biases[i] = -8 * scales[i];
        unpack_32_4(data, weights);
        weights += 16;
        data += bytes_per_block;
    }
}

// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
    const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
    uint8_t* weights = mlx_array_data_uint8(*weights_arr);
    float16_t* scales = mlx_array_data_float16(*scales_arr);
    float16_t* biases = mlx_array_data_float16(*biases_arr);
    for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
        scales[i] = *((float16_t*)data);
        biases[i] = *((float16_t*)(data) + 1);
        unpack_32_4(data, weights);
        weights += 16;
        data += bytes_per_block;
    }
}

// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
    const uint64_t weights_per_block = 32;
    const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
    uint8_t* weights = mlx_array_data_uint8(*weights_arr);
    float16_t* scales = mlx_array_data_float16(*scales_arr);
    float16_t* biases = mlx_array_data_float16(*biases_arr);
    for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
        uint8_t* block_data = data + i * bytes_per_block;
        scales[i] = *((float16_t*)block_data);
        biases[i] = -128 * scales[i];
        for (int64_t j = 0; j < weights_per_block; ++j) {
            uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
            // Original data is in int8_t, so we add a bias of -128 and invert the
            // first bit.
            x ^= 1 << 7;
            weights[i * weights_per_block + j] = x;
        }
    }
}

// Derived from ggml-quants.c

#define QK_K 256

// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
    uint8_t ql[QK_K/2];     // quants, lower 4 bits
    uint8_t qh[QK_K/4];     // quants, upper 2 bits
    int8_t scales[QK_K/16]; // scales, quantized with 8 bits
    uint16_t d;             // super-block scale
} block_q6_K;
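// Size check (editorial note, not in the original): one block covers QK_K = 256
// weights in 128 + 64 + 16 + 2 = 210 bytes, i.e. 210 * 8 / 256 = 6.5625 bits
// per weight, matching the figure above.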

void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
    const int64_t nb = k / QK_K;
    block_q6_K *x = (block_q6_K *)vx;
    float16_t* y = (float16_t *)vy;

    for (int i = 0; i < nb; i++) {
        float16_t d = 0.0;
        memcpy(&d, &x[i].d, sizeof(d));

        const uint8_t * restrict ql = x[i].ql;
        const uint8_t * restrict qh = x[i].qh;
        const int8_t * restrict sc = x[i].scales;

        for (int n = 0; n < QK_K; n += 128) {
            for (int l = 0; l < 32; ++l) {
                int is = l/16;
                const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
                y[l + 0] = d * sc[is + 0] * q1;
                y[l + 32] = d * sc[is + 2] * q2;
                y[l + 64] = d * sc[is + 4] * q3;
                y[l + 96] = d * sc[is + 6] * q4;
            }
            y += 128;
            ql += 64;
            qh += 32;
            sc += 8;
        }
    }
}

#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S

// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
    union {
        struct {
            uint16_t d;    // super-block scale for quantized scales
            uint16_t dmin; // super-block scale for quantized mins
        } GGML_COMMON_AGGR_S;
        uint16_t dm;
    } GGML_COMMON_AGGR_U;
    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
    uint8_t qs[QK_K/2];           // 4-bit quants
} block_q4_K;
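// Size check (editorial note, not in the original): 4 + 12 + 128 = 144 bytes
// per 256 weights, i.e. 144 * 8 / 256 = 4.5 bits per weight, matching the
// figure above.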

static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
    if (j < 4) {
        *d = q[j] & 63; *m = q[j + 4] & 63;
    } else {
        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
    }
}

void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
    block_q4_K *x = (block_q4_K *)vx;
    float16_t* y = (float16_t *)vy;
    const int nb = k / QK_K;

    for (int i = 0; i < nb; i++) {
        const uint8_t * q = x[i].qs;
        float16_t d = 0.0;
        memcpy(&d, &x[i].d, sizeof(d));
        float16_t min = 0.0;
        memcpy(&min, &x[i].dmin, sizeof(min));

        int is = 0;
        uint8_t sc, m;
        for (int j = 0; j < QK_K; j += 64) {
            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
            const float16_t d1 = d * sc; const float16_t m1 = min * m;
            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
            const float16_t d2 = d * sc; const float16_t m2 = min * m;
            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
            q += 32; is += 2;
        }
    }
}

*/
import "C"

import (
    "fmt"
    "unsafe"

    "github.com/x448/float16"
)

func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
    shape := append([]C.int{}, final_shape...)
    var weights_per_byte C.int
    if dtype == 2 || dtype == 3 { // GGML tensor types Q4_0 and Q4_1
        weights_per_byte = 2
    } else if dtype == 8 { // GGML tensor type Q8_0
        weights_per_byte = 1
    } else {
        return r, fmt.Errorf("unsupported tensor type %d", dtype)
    }

    weights_per_block := C.int(32)
    if shape[len(shape)-1]%weights_per_block != 0 {
        return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
    }

    weights_shape := append([]C.int{}, shape...)
    weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
    w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
    for i := range weights_shape {
        w_nbytes *= weights_shape[i]
    }
    w_data := make([]byte, w_nbytes)
    cbytes := C.CBytes(w_data)
    defer C.free(cbytes)
    weights := C.mlx_array_new_data(
        cbytes,
        &weights_shape[0],
        C.int(len(weights_shape)),
        C.MLX_UINT32,
    )

    // For scales and bias
    shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
    sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
    for i := range shape {
        sb_nbytes *= shape[i]
    }

    s_data := make([]byte, sb_nbytes)
    cbytes = C.CBytes(s_data)
    defer C.free(cbytes)
    scales := C.mlx_array_new_data(
        cbytes,
        &shape[0],
        C.int(len(shape)),
        C.MLX_FLOAT16,
    )
    b_data := make([]byte, sb_nbytes)
    cbytes = C.CBytes(b_data)
    defer C.free(cbytes)
    biases := C.mlx_array_new_data(
        cbytes,
        &shape[0],
        C.int(len(shape)),
        C.MLX_FLOAT16,
    )
    var bits C.int
    switch dtype {
    case 2:
        C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
        bits = 4
    case 3:
        C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
        bits = 4
    case 8:
        C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
        bits = 8
    }
    groupSize := C.mlx_optional_int{value: 32, has_value: true}
    bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
    var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
    C.mlx_dequantize(
        &r,
        weights,
        scales,
        biases,
        groupSize,
        bitsOpt,
        nil, // TODO mode
        dtypeOpt,
        stream,
    )
    C.mlx_array_free(weights)
    C.mlx_array_free(scales)
    C.mlx_array_free(biases)

    return r, nil
}

func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
    size := 1
    for _, d := range shape {
        size *= int(d)
    }
    fdata := make([]float16.Float16, size)
    switch dtype {
    case 14: // GGML tensor type Q6_K
        C.dequant_row_q6_K(
            data,
            unsafe.Pointer(&fdata[0]),
            C.int(size),
        )

    case 12: // GGML tensor type Q4_K
        C.dequant_row_q4_K(
            data,
            unsafe.Pointer(&fdata[0]),
            C.int(size),
        )
    default:
        return r, fmt.Errorf("unsupported K quant")
    }

    r = C.mlx_array_new_data(
        unsafe.Pointer(&fdata[0]),
        &shape[0],
        C.int(len(shape)),
        C.MLX_FLOAT16,
    )
    return r, nil
}
643
x/ml/device.go
@@ -1,643 +0,0 @@
package ml

import (
    "context"
    "encoding/binary"
    "encoding/json"
    "fmt"
    "hash/maphash"
    "io"
    "log/slog"
    "math"
    "net/http"
    "runtime"
    "slices"
    "sort"
    "strconv"
    "strings"
    "time"

    "github.com/ollama/ollama/format"
    "github.com/ollama/ollama/logutil"
)

// GPULayers is a set of layers to be allocated on a single GPU
type GPULayers struct {
    DeviceID

    // Layers is a set of layer indices to load
    Layers []int
}

// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
func (g GPULayers) FirstLayer() int {
    if len(g.Layers) == 0 {
        return math.MaxInt
    }

    first := g.Layers[0]
    for i := 1; i < len(g.Layers); i++ {
        if g.Layers[i] < first {
            first = g.Layers[i]
        }
    }

    return first
}

func (g GPULayers) String() string {
    if len(g.Layers) == 0 {
        return ""
    }

    slices.Sort(g.Layers)

    contiguous := true
    base := g.Layers[0]
    for i := range g.Layers {
        if g.Layers[i] != base+i {
            contiguous = false
            break
        }
    }

    if contiguous {
        return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
    } else {
        return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
    }
}

// GPULayersList is a set of layer allocations across multiple GPUs
type GPULayersList []GPULayers

func (l GPULayersList) Len() int      { return len(l) }
func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }

// Sort by the ordering of the layers offloaded
func (l GPULayersList) Less(i, j int) bool {
    li := l[i].FirstLayer()
    lj := l[j].FirstLayer()

    return li < lj
}

func (l GPULayersList) String() string {
    if l.Sum() > 0 {
        return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
    } else {
        return fmt.Sprintf("%v", []GPULayers(l))
    }
}

// Sum is the total number of layers assigned across all GPUs
func (l GPULayersList) Sum() int {
    var sum int

    for _, g := range l {
        sum += len(g.Layers)
    }

    return sum
}

var h maphash.Hash

// Hash is an identifier of this layer assignment
func (l GPULayersList) Hash() uint64 {
    h.Reset()
    for _, g := range l {
        if len(g.Layers) > 0 {
            h.WriteString(g.ID + g.Library)
            for _, l := range g.Layers {
                binary.Write(&h, binary.NativeEndian, int64(l))
            }
        }
    }

    return h.Sum64()
}

// ErrNoMem is returned when panicking due to insufficient memory. It includes
// the attempted memory allocation.
type ErrNoMem struct {
    BackendMemory
}

func (e ErrNoMem) Error() string {
    return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
}

// DeviceID is the minimal unique identification of a device
type DeviceID struct {
    // ID is an identifier for the device for matching with system
    // management libraries. The ID is only unique among devices
    // using the same Library.
    // This ID represents a "post filtered" view of the enumerated devices
    // if the ID is numeric
    ID string `json:"id"`

    // Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
    Library string `json:"backend,omitempty"`
}

// DeviceMemory provides a breakdown of the memory needed
// per device, such as a CPU or GPU.
type DeviceMemory struct {
    DeviceID

    // Name is the name of the device as labeled by the backend. It
    // may not be persistent across instances of the runner.
    Name string

    // Weights is the per-layer memory needed for the model weights.
    Weights []uint64

    // Cache is the per-layer memory needed for the KV cache.
    Cache []uint64

    // Graph is the size of the compute graph. It is not per-layer.
    Graph uint64
}

func sumMemory(mem []uint64) uint64 {
    var sum uint64

    for _, m := range mem {
        sum += m
    }

    return sum
}

// Size returns the total size of the memory required by this device
func (m DeviceMemory) Size() uint64 {
    return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
}

func memoryPresent(mem []uint64) bool {
    return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
}

func (m DeviceMemory) LogValue() slog.Value {
    var attrs []slog.Attr
    if memoryPresent(m.Weights) {
        attrs = append(attrs, slog.Any("Weights", m.Weights))
    }

    if memoryPresent(m.Cache) {
        attrs = append(attrs, slog.Any("Cache", m.Cache))
    }

    if m.Graph != 0 {
        attrs = append(attrs, slog.Any("Graph", m.Graph))
    }

    if len(attrs) > 0 && m.ID != "" {
        attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
    }

    return slog.GroupValue(attrs...)
}

// BackendMemory provides the amount of memory required to load the model
// per device based on the BackendParams. In some cases, not all required
// allocations will be known at this point. However, the size of the most recent
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
    // InputWeights are always located on the CPU and cannot be moved
    InputWeights uint64

    // CPU model components are located in system memory. This does not
    // include unified memory allocated through the GPU.
    CPU DeviceMemory

    // GPU model components are located on one or more GPUs.
    GPUs []DeviceMemory
}

func (m BackendMemory) LogValue() slog.Value {
    var attrs []slog.Attr
    if m.InputWeights != 0 {
        attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
    }

    attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
    for _, g := range m.GPUs {
        attrs = append(attrs, slog.Any(g.Name, g))
    }

    return slog.GroupValue(attrs...)
}

// Log prints a high level summary of the memory
func (m BackendMemory) Log(level slog.Level) {
    var total uint64

    for _, gpu := range m.GPUs {
        if sum := sumMemory(gpu.Weights); sum > 0 {
            slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
            total += sum
        }
    }
    if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
        slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
        total += sum
    }

    for _, gpu := range m.GPUs {
        if sum := sumMemory(gpu.Cache); sum > 0 {
            slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
            total += sum
        }
    }
    if sum := sumMemory(m.CPU.Cache); sum > 0 {
        slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
        total += sum
    }

    for _, gpu := range m.GPUs {
        if sum := gpu.Graph; sum > 0 {
            slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
            total += sum
        }
    }
    if sum := m.CPU.Graph; sum > 0 {
        slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
        total += sum
    }

    if total > 0 {
        slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
    }
}

type DeviceInfo struct {
    DeviceID

    // Name is the name of the device as labeled by the backend. It
    // may not be persistent across instances of the runner.
    Name string `json:"name"`

    // Description is the longer user-friendly identification of the device
    Description string `json:"description"`

    // FilterID is populated with the unfiltered device ID if a numeric ID is used
    // so the device can be included.
    FilterID string `json:"filter_id,omitempty"`

    // Integrated is set true for integrated GPUs, false for discrete GPUs
    Integrated bool `json:"integration,omitempty"`

    // PCIID is the bus, device and domain ID of the device for deduplication
    // when discovered by multiple backends
    PCIID string `json:"pci_id,omitempty"`

    // TotalMemory is the total amount of memory the device can use for loading models
    TotalMemory uint64 `json:"total_memory"`

    // FreeMemory is the amount of memory currently available on the device for loading models
    FreeMemory uint64 `json:"free_memory,omitempty"`

    // ComputeMajor is the major version of capabilities of the device;
    // if unsupported by the backend, -1 will be returned
    ComputeMajor int

    // ComputeMinor is the minor version of capabilities of the device;
    // if unsupported by the backend, -1 will be returned
    ComputeMinor int

    // Driver Information
    DriverMajor int `json:"driver_major,omitempty"`
    DriverMinor int `json:"driver_minor,omitempty"`

    // Where backends were loaded from
    LibraryPath []string
}

type SystemInfo struct {
    // ThreadCount is the optimal number of threads to use for inference
    ThreadCount int `json:"threads,omitempty"`

    // TotalMemory is the total amount of system memory
    TotalMemory uint64 `json:"total_memory,omitempty"`

    // FreeMemory is the amount of memory currently available on the system for loading models
    FreeMemory uint64 `json:"free_memory,omitempty"`

    // FreeSwap is the amount of system swap space reported as available
    FreeSwap uint64 `json:"free_swap,omitempty"`
}

func (d DeviceInfo) Compute() string {
    // AMD gfx is encoded into the major/minor versions in hex form
    if strings.EqualFold(d.Library, "ROCm") {
        return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
    }
    return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
}
|
||||
|
||||
func (d DeviceInfo) Driver() string {
|
||||
return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
|
||||
}
|
||||
|
||||
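A worked example of the two encodings above (illustrative only, standard library only): ROCm packs major/minor into a hex gfx name, while every other backend reports decimal major.minor:

	package main

	import "fmt"

	func main() {
		// ROCm: major 9, minor 0x0a -> "gfx90a" (an MI200-series part)
		fmt.Printf("gfx%x%02x\n", 9, 0x0a)
		// other backends: e.g. CUDA compute capability 8.6 -> "8.6"
		fmt.Printf("%d.%d\n", 8, 6)
	}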
// MinimumMemory reports the amount of memory that should be set aside
// on the device for overhead (e.g. VRAM consumed by context structures independent
// of model allocations)
func (d DeviceInfo) MinimumMemory() uint64 {
	if d.Library == "Metal" {
		return 512 * format.MebiByte
	}
	return 457 * format.MebiByte
}

// ByFreeMemory sorts devices by free memory, ascending.
// iGPUs sort first, thus Reverse() yields the largest discrete GPU first.
type ByFreeMemory []DeviceInfo

func (a ByFreeMemory) Len() int      { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
	if a[i].Integrated && !a[j].Integrated {
		return true
	} else if !a[i].Integrated && a[j].Integrated {
		return false
	}
	return a[i].FreeMemory < a[j].FreeMemory
}
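Since iGPUs sort first, a scheduler wanting the best discrete device can reverse the order; a hypothetical helper (not part of this diff):

	func pickLargestFree(devs []DeviceInfo) DeviceInfo {
		// assumes len(devs) > 0
		sort.Sort(sort.Reverse(ByFreeMemory(devs)))
		return devs[0] // the discrete GPU with the most free memory, per the ordering above
	}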
// ByPerformance groups devices into similar performance tiers; currently the
// only signal used is whether the device is integrated or discrete.
func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
	resp := [][]DeviceInfo{}
	scores := []bool{}
	for _, info := range l {
		found := false
		requested := info.Integrated
		for i, score := range scores {
			if score == requested {
				resp[i] = append(resp[i], info)
				found = true
				break
			}
		}
		if !found {
			scores = append(scores, requested)
			resp = append(resp, []DeviceInfo{info})
		}
	}
	return resp
}

func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
	resp := [][]DeviceInfo{}
	libs := []string{}
	for _, info := range l {
		found := false
		requested := info.Library
		for i, lib := range libs {
			if lib == requested {
				resp[i] = append(resp[i], info)
				found = true
				break
			}
		}
		if !found {
			libs = append(libs, requested)
			resp = append(resp, []DeviceInfo{info})
		}
	}
	return resp
}

func LibraryPaths(l []DeviceInfo) []string {
	gpuLibs := []string{LibOllamaPath}
	for _, gpu := range l {
		for _, dir := range gpu.LibraryPath {
			needed := true
			for _, existing := range gpuLibs {
				if dir == existing {
					needed = false
					break
				}
			}
			if needed {
				gpuLibs = append(gpuLibs, dir)
			}
		}
	}
	return gpuLibs
}

type DeviceComparison int

const (
	UniqueDevice      DeviceComparison = iota
	SameBackendDevice // The device is the same, and the library/backend is the same
	DuplicateDevice   // The same physical device but different library/backend (overlapping device)
)

func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
	if a.PCIID != b.PCIID {
		return UniqueDevice
	}
	// If PCIID is empty, we have to use ID + library for uniqueness
	if a.PCIID == "" && a.DeviceID != b.DeviceID {
		return UniqueDevice
	}
	if a.Library == b.Library {
		return SameBackendDevice
	}
	return DuplicateDevice
}
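An illustrative decision table for the three results (the comments sketch typical scheduler policy, not code from this diff):

	switch a.Compare(b) {
	case SameBackendDevice:
		// same card, same backend: keep one copy, preferring the newer library (see IsBetter below)
	case DuplicateDevice:
		// same card seen by two backends (e.g. CUDA and Vulkan): keep the preferred library
	case UniqueDevice:
		// distinct hardware: keep both
	}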
// IsBetter reports, for a SameBackendDevice, whether b is better than a,
// e.g. b was discovered by a newer GPU library version.
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
	aLib := a.LibraryPath[len(a.LibraryPath)-1]
	bLib := b.LibraryPath[len(b.LibraryPath)-1]
	if aLib == bLib {
		return false
	}
	aLibSplit := strings.SplitN(aLib, "_", 2)
	bLibSplit := strings.SplitN(bLib, "_", 2)
	if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
		return false
	}
	if aLibSplit[0] != bLibSplit[0] {
		slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
		return false
	}
	if aLibSplit[1] == bLibSplit[1] {
		return false
	}
	cmp := []string{aLibSplit[1], bLibSplit[1]}
	sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
	return cmp[0] == bLibSplit[1]
}
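Concretely, the comparison reduces to a reverse lexicographic sort of the suffix after the first "_" in the last LibraryPath entry. For hypothetical library directories cuda_v12 and cuda_v13:

	cmp := []string{"v12", "v13"}
	sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
	// cmp[0] == "v13", so the device from cuda_v13 wins

Note the ordering is lexicographic rather than numeric, so version suffixes with differing digit counts would need consistent zero-padding to compare correctly.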
// FlashAttentionSupported returns true only if every GPU in the list supports
// flash attention.
func FlashAttentionSupported(l []DeviceInfo) bool {
	for _, gpu := range l {
		supportsFA := gpu.Library == "cpu" ||
			gpu.Name == "Metal" || gpu.Library == "Metal" ||
			(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
			gpu.Library == "ROCm" ||
			gpu.Library == "Vulkan"

		if !supportsFA {
			return false
		}
	}
	return true
}

// GetVisibleDevicesEnv takes the list of GPUs this instantiation is targeted for
// and figures out the visible devices environment variables.
// Set mustFilter true to enable filtering of CUDA devices.
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
	if len(l) == 0 {
		return nil
	}
	env := map[string]string{}
	for _, d := range l {
		d.updateVisibleDevicesEnv(env, mustFilter)
	}
	return env
}
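A hypothetical mixed run, following the per-library rules in updateVisibleDevicesEnv below: two ROCm devices with IDs "0" and "1" on Linux yield a single comma-joined variable, while CUDA devices are left unfiltered unless forced:

	env := GetVisibleDevicesEnv(rocmDevs, false)
	// env["ROCR_VISIBLE_DEVICES"] == "0,1" (HIP_VISIBLE_DEVICES on non-Linux)

	env = GetVisibleDevicesEnv(cudaDevs, true)
	// env["CUDA_VISIBLE_DEVICES"] lists each device's FilterID (or ID);
	// with mustFilter false the map comes back empty for CUDA-only lists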
// NeedsInitValidation returns true if the device in question has the potential
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
	// ROCm: rocblas will crash on unsupported devices.
	// CUDA: verify CC is supported by the version of the library
	return d.Library == "ROCm" || d.Library == "CUDA"
}

// AddInitValidation sets the init validation environment variable
func (d DeviceInfo) AddInitValidation(env map[string]string) {
	env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
}

// PreferredLibrary returns true if this library is preferred over the other
// input library.
// Used to filter out Vulkan in favor of CUDA or ROCm.
func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
	// TODO in the future if we find Vulkan is better than ROCm on some devices
	// that implementation can live here.

	if d.Library == "CUDA" || d.Library == "ROCm" {
		return true
	}
	return false
}

func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
	var envVar string
	switch d.Library {
	case "ROCm":
		// ROCm must be filtered as it can crash the runner on unsupported devices
		envVar = "ROCR_VISIBLE_DEVICES"
		if runtime.GOOS != "linux" {
			envVar = "HIP_VISIBLE_DEVICES"
		}
	case "CUDA":
		if !mustFilter {
			// By default we try to avoid filtering CUDA devices because ROCm also
			// looks at the CUDA env var, and gets confused in mixed vendor environments.
			return
		}
		envVar = "CUDA_VISIBLE_DEVICES"
	default:
		// Vulkan is not filtered via env var, but via scheduling decisions
		return
	}
	v, existing := env[envVar]
	if existing {
		v = v + ","
	}
	if d.FilterID != "" {
		v = v + d.FilterID
	} else {
		v = v + d.ID
	}
	env[envVar] = v
}

type BaseRunner interface {
	// GetPort returns the localhost port number the runner is running on
	GetPort() int

	// HasExited indicates if the runner is no longer running. This can be used during
	// bootstrap to detect if a given filtered device is incompatible and triggered an assert
	HasExited() bool
}

type RunnerDiscovery interface {
	BaseRunner

	// GetDeviceInfos will perform a query of the underlying device libraries
	// for device identification and free VRAM information.
	// During bootstrap scenarios, this routine may take seconds to complete.
	GetDeviceInfos(ctx context.Context) []DeviceInfo
}

type FilteredRunnerDiscovery interface {
	RunnerDiscovery

	// GetActiveDeviceIDs returns the filtered set of devices actively in
	// use by this runner for running models. If the runner is a bootstrap runner,
	// no devices will be active yet so no device IDs are returned.
	// This routine will not query the underlying device and will return immediately.
	GetActiveDeviceIDs() []DeviceID
}

func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
	var moreDevices []DeviceInfo
	port := runner.GetPort()
	tick := time.Tick(10 * time.Millisecond)
	for {
		select {
		case <-ctx.Done():
			return nil, fmt.Errorf("failed to finish discovery before timeout")
		case <-tick:
			r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
			if err != nil {
				return nil, fmt.Errorf("failed to create request: %w", err)
			}
			r.Header.Set("Content-Type", "application/json")

			resp, err := http.DefaultClient.Do(r)
			if err != nil {
				// slog.Warn("failed to send request", "error", err)
				if runner.HasExited() {
					return nil, fmt.Errorf("runner crashed")
				}
				continue
			}

			if resp.StatusCode == http.StatusNotFound {
				// old runner, fall back to bootstrapping model
				resp.Body.Close()
				return nil, fmt.Errorf("llamarunner free vram reporting not supported")
			}

			body, err := io.ReadAll(resp.Body)
			// Close eagerly: a defer here would accumulate one open body per
			// polling iteration until the function returns.
			resp.Body.Close()
			if err != nil {
				slog.Warn("failed to read response", "error", err)
				continue
			}
			if resp.StatusCode != 200 {
				logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
				return nil, fmt.Errorf("runner error: %s", string(body))
			}

			if err := json.Unmarshal(body, &moreDevices); err != nil {
				slog.Warn("failed to unmarshal response", "error", err)
				continue
			}
			return moreDevices, nil
		}
	}
}
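Callers are expected to bound the 10ms poll loop with the context, since ctx.Done() is the only exit when the runner never answers yet never exits; a hypothetical wrapper (not part of this diff):

	func discoverWithTimeout(runner BaseRunner) ([]DeviceInfo, error) {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		return GetDevicesFromRunner(ctx, runner)
	}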
@@ -1,103 +0,0 @@
package nn

import (
	"fmt"

	"github.com/ollama/ollama/x/kvcache"
	"github.com/ollama/ollama/x/ml"
)

// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
//   - ctx: Context for tensor operations
//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
//	Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}

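As an illustrative sketch (not part of this deleted file), a model's attention layer would typically derive the scale from its head dimension before calling this helper; headDim, q, k, v, and cache are assumed from the surrounding model code:

	scale := 1.0 / math.Sqrt(float64(headDim))       // the 1/√d_k factor from the doc comment
	out := nn.Attention(ctx, q, k, v, scale, cache)  // -> [d_v, heads, seq_len_q]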
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}

func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	ctx.Forward(query)

	if key != nil && value != nil {
		if query.Dim(0) != key.Dim(0) {
			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
		}

		if key.Dim(1) != value.Dim(1) {
			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
		}

		if key.Dim(2) != value.Dim(2) {
			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
		}

		ctx.Forward(key, value)
		if cache != nil {
			cache.Put(ctx, key, value)
		}
	} else if cache == nil {
		panic("key & value tensors must be provided if cache is nil")
	}

	// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
	// panic("after cache get") //
	// 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
	// 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
	// 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]

	// var mask ml.Tensor
	if cache != nil {
		key, value, _ = cache.Get(ctx)
	}
	// ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
	// panic("after cache get") //
	// 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
	// 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
	// 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]

	// Only use the fast SDPA implementation if we have a cache, since that's what
	// will do any expected backend-specific transformations for us

	if cache != nil {
		// TODO what to do with vmla?
		// return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
		return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)

		// TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
	} else {
		panic("else case not supported")
		// TODO transpose shapes are wrong
		// key = key.Transpose(ctx, 0, 2, 1, 3)
		// value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)

		// kq := query.Matmul(ctx, key)

		// kq = kq.Scale(ctx, scale)
		// if mask != nil {
		// 	kq = kq.Add(ctx, mask)
		// }
		// kq = kq.Softmax(ctx)

		// kqv := kq.Matmul(ctx, value)

		// if vmla != nil {
		// 	kqv = kqv.Matmul(ctx, vmla)
		// }

		// return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
	}
}
@@ -1,30 +0,0 @@
package nn

import "github.com/ollama/ollama/x/ml"

type Conv2D struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
	if m.Bias != nil {
		// Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
		t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
	}
	return t
}
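A hypothetical call site for orientation (ctx and img come from the surrounding model code; the trailing literal 1 in the Conv2D call above is the group count): stride, padding, and dilation are passed per axis:

	var patchEmbed Conv2D // Weight/Bias populated via the gguf tags
	// 2x2 stride, 1-pixel padding, no dilation — a typical patch embedding setup
	out := patchEmbed.Forward(ctx, img, 2, 2, 1, 1, 1, 1)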

type Conv3D struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
	t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
	if m.Bias != nil {
		t = t.Add(ctx, m.Bias)
	}
	return t
}
@@ -1,11 +0,0 @@
package nn

import "github.com/ollama/ollama/x/ml"

type Embedding struct {
	Weight ml.Tensor `gguf:"weight"`
}

func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
	return m.Weight.TakeAxes(ctx, hiddenState, 0)
}
@@ -1,32 +0,0 @@
package nn

import "github.com/ollama/ollama/x/ml"

type Linear struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
	t = t.Matmul(ctx, m.Weight.Transpose(ctx))
	if m.Bias != nil {
		t = t.Add(ctx, m.Bias)
	}

	return t
}
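For orientation (an illustrative sketch, assuming the weight is stored [out_features, in_features] as in MLX's Linear): the transpose-then-matmul above is the usual affine map y = xWᵀ + b:

	var fc Linear // Weight: [outFeatures, inFeatures], Bias: [outFeatures]
	y := fc.Forward(ctx, x) // x: [..., inFeatures] -> y: [..., outFeatures]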

type LinearBatch struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
	panic("not yet ported")
	// t = m.Weight.MulmatID(ctx, t, indices)
	// if m.Bias != nil {
	// 	t = t.AddID(ctx, m.Bias, indices)
	// }

	// return t
}
@@ -1,29 +0,0 @@
package nn

import (
	"github.com/ollama/ollama/x/ml"
)

type LayerNorm struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
	return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}

type RMSNorm struct {
	Weight ml.Tensor `gguf:"weight"`
}

func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
	// slog.Info("RMSNorm", "eps", eps)
	// fmt.Fprintln(os.Stderr, t.ToString())
	// fmt.Fprintln(os.Stderr, m.Weight.ToString())

	// TODO this is probably model specific, not generalized...
	w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))

	return t.RMSNorm(ctx, w, eps)
}
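For reference, a scalar sketch (illustrative, not part of this file) of what t.RMSNorm(ctx, w, eps) computes per feature vector; the +1 shift applied above matches the Gemma-style convention of storing the norm weights centered at zero:

	func rmsnormRow(x, w []float32, eps float32) []float32 {
		var ss float32
		for _, v := range x {
			ss += v * v
		}
		// 1 / sqrt(mean(x^2) + eps)
		inv := float32(1 / math.Sqrt(float64(ss/float32(len(x)))+float64(eps)))
		out := make([]float32, len(x))
		for i := range x {
			out[i] = x[i] * inv * w[i] // w already includes the +1 shift
		}
		return out
	}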
@@ -1,41 +0,0 @@
package pooling

import (
	"github.com/ollama/ollama/x/ml"
)

type Type uint32

const (
	TypeNone Type = iota
	TypeMean
	TypeCLS
	TypeLast
)

func (t Type) String() string {
	switch t {
	case TypeMean:
		return "Mean"
	case TypeCLS:
		return "CLS"
	case TypeLast:
		return "Last"
	default:
		return "Unknown"
	}
}

func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
	switch t {
	// case TypeMean:
	// 	hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
	// 	return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
	// case TypeCLS:
	// 	return hiddenStates.Slice(ctx, 1, 0, 1, 1)
	// case TypeLast:
	// 	return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
	default:
		panic("unknown pooling type")
	}
}
@@ -1,72 +0,0 @@
package rope

import "github.com/ollama/ollama/x/ml"

// Options contains optional parameters for RoPE function
type Options struct {
	Type    int
	Factors ml.Tensor

	// YaRN options
	YaRN struct {
		OriginalContextLength int
		ExtrapolationFactor,
		AttentionFactor,
		BetaFast,
		BetaSlow float32
	}

	// MRoPE options
	MRoPE struct {
		Sections []int
	}
}

// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
	return func(opts *Options) {
		opts.Type = 2
	}
}

// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
	return func(opts *Options) {
		if factors != nil {
			opts.Factors = factors
		}
	}
}

// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
	return func(opts *Options) {
		opts.YaRN.OriginalContextLength = n
	}
}

func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
	return func(opts *Options) {
		opts.YaRN.ExtrapolationFactor = extrapolationFactor
	}
}

func WithAttentionFactor(attentionFactor float32) func(*Options) {
	return func(opts *Options) {
		opts.YaRN.AttentionFactor = attentionFactor
	}
}

func WithMRoPE(sections []int) func(*Options) {
	return func(opts *Options) {
		opts.Type |= 1 << 3
		opts.MRoPE.Sections = sections
	}
}

func WithInterleaveMRoPE(sections []int) func(*Options) {
	return func(opts *Options) {
		opts.Type |= 1<<3 | 1<<5
		opts.MRoPE.Sections = sections
	}
}
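To round out the functional-options pattern above (an illustrative sketch; the RoPE function that consumes Options is outside this hunk, and the section sizes are hypothetical): callers apply each closure in turn, and the MRoPE variants OR mode bits into Type rather than replacing it:

	opts := &rope.Options{}
	for _, with := range []func(*rope.Options){
		rope.WithTypeNeoX(),                         // opts.Type == 2 after this
		rope.WithOriginalContextLength(8192),        // YaRN context extension
		rope.WithInterleaveMRoPE([]int{16, 24, 24}), // ORs bits 3 and 5 into Type
	} {
		with(opts)
	}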