Compare commits

..

7 Commits

Author SHA1 Message Date
Gyungrai Wang
55d0b6e8b9 integration: fix tools_test.go for ToolCallFunctionArguments API change (#13731) 2026-01-15 16:08:09 -08:00
Devon Rifkin
38eac40d56 openai: tweak v1/responses to conform better (#13736)
* openai: tweak v1/responses to conform better

* openai: provide better error for image URLs

* lint
2026-01-15 15:46:36 -08:00
Jeffrey Morgan
80f3f1bc25 readme: add instructions to build with MLX (#13733) 2026-01-15 11:03:52 -08:00
Parth Sareen
b1a0db547b docs: add env var needed for claude code in docs (#13721) 2026-01-15 10:11:00 -08:00
Parth Sareen
75d7b5f926 cmd: enable multi-line input and shift enter (#13694) 2026-01-14 17:52:46 -08:00
vincent d warmerdam
349d814814 docs: add marimo integration (#13326)
* docs added

* fix title

* add marimo to docs.json

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:37:38 -08:00
Yuhong Sun
c8743031e0 docs: add onyx integration (#13135)
* Ready for team review

* Update docs/integrations/onyx.mdx

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* update docs.json

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:32:05 -08:00
37 changed files with 815 additions and 2769 deletions

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
```shell
go build -tags mlx -o ollama-mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -46,9 +46,8 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -94,82 +93,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
modelName := args[0]
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if create.IsTensorModelDir(".") {
if imagegen.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: args[0],
ModelDir: ".",
Quantize: quantize,
}, p)
return imagegenclient.CreateModel(args[0], ".", quantize, p)
}
reader = strings.NewReader("FROM .\n")
} else {
@@ -1810,22 +1742,15 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
// Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{
Use: "show MODEL",

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err

View File

@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

View File

@@ -111,7 +111,9 @@
"/integrations/zed",
"/integrations/roo-code",
"/integrations/n8n",
"/integrations/xcode"
"/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
]
},
{

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

BIN
docs/images/marimo-chat.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 178 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

BIN
docs/images/onyx-login.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 306 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 300 KiB

BIN
docs/images/onyx-query.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 211 KiB

View File

@@ -25,6 +25,7 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
@@ -38,7 +39,7 @@ claude --model qwen3-coder
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com

View File

@@ -0,0 +1,73 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, go to the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base url to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -0,0 +1,63 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Login to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the set-up process select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -8,6 +8,7 @@ import (
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -441,6 +442,7 @@ type ResponsesWriter struct {
stream bool
responseID string
itemID string
request openai.ResponsesRequest
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
completedAt := time.Now().Unix()
response.CompletedAt = &completedAt
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
request: req,
}
// Set headers based on streaming mode

View File

@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
}
types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"time"
"github.com/ollama/ollama/api"
)
@@ -265,9 +266,9 @@ type ResponsesText struct {
type ResponsesTool struct {
Type string `json:"type"` // "function"
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
Description *string `json:"description"` // nullable but required
Strict *bool `json:"strict"` // nullable but required
Parameters map[string]any `json:"parameters"` // nullable but required
}
type ResponsesRequest struct {
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
}
}
var description string
if t.Description != nil {
description = *t.Description
}
return api.Tool{
Type: t.Type,
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Description: description,
Parameters: params,
},
}, nil
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
// Response types for the Responses API
// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
Format ResponsesTextFormat `json:"format"`
}
// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
Effort *string `json:"effort,omitempty"`
Summary *string `json:"summary,omitempty"`
}
// ResponsesError represents an error in the response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"`
}
type ResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Status string `json:"status"`
Model string `json:"model"`
Output []ResponsesOutputItem `json:"output"`
Usage *ResponsesUsage `json:"usage,omitempty"`
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
// requires additional plumbing to find the effective values since the
// defaults can come from the model or the request
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
CompletedAt *int64 `json:"completed_at"`
Status string `json:"status"`
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
Model string `json:"model"`
PreviousResponseID *string `json:"previous_response_id"`
Instructions *string `json:"instructions"`
Output []ResponsesOutputItem `json:"output"`
Error *ResponsesError `json:"error"`
Tools []ResponsesTool `json:"tools"`
ToolChoice any `json:"tool_choice"`
Truncation string `json:"truncation"`
ParallelToolCalls bool `json:"parallel_tool_calls"`
Text ResponsesTextField `json:"text"`
TopP float64 `json:"top_p"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
TopLogprobs int `json:"top_logprobs"`
Temperature float64 `json:"temperature"`
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
Usage *ResponsesUsage `json:"usage"`
MaxOutputTokens *int `json:"max_output_tokens"`
MaxToolCalls *int `json:"max_tool_calls"`
Store bool `json:"store"`
Background bool `json:"background"`
ServiceTier string `json:"service_tier"`
Metadata map[string]any `json:"metadata"`
SafetyIdentifier *string `json:"safety_identifier"`
PromptCacheKey *string `json:"prompt_cache_key"`
}
type ResponsesOutputItem struct {
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
}
type ResponsesOutputContent struct {
Type string `json:"type"` // "output_text"
Text string `json:"text"`
Type string `json:"type"` // "output_text"
Text string `json:"text"`
Annotations []any `json:"annotations"`
Logprobs []any `json:"logprobs"`
}
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens"`
}
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
}
type ResponsesUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
}
// ToResponse converts an api.ChatResponse to a Responses API response
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
func derefFloat64(p *float64, def float64) float64 {
if p != nil {
return *p
}
return def
}
// ToResponse converts an api.ChatResponse to a Responses API response.
// The request is used to echo back request parameters in the response.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
var output []ResponsesOutputItem
// Add reasoning item if thinking is present
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
output = append(output, ResponsesOutputItem{
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
Type: "function_call",
Status: "completed",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
Role: "assistant",
Content: []ResponsesOutputContent{
{
Type: "output_text",
Text: chatResponse.Message.Content,
Type: "output_text",
Text: chatResponse.Message.Content,
Annotations: []any{},
Logprobs: []any{},
},
},
})
}
var instructions *string
if request.Instructions != "" {
instructions = &request.Instructions
}
// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
truncation = *request.Truncation
}
tools := request.Tools
if tools == nil {
tools = []ResponsesTool{}
}
text := ResponsesTextField{
Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
text.Format = *request.Text.Format
}
// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
reasoning = &ResponsesReasoningOutput{}
if request.Reasoning.Effort != "" {
reasoning.Effort = &request.Reasoning.Effort
}
if request.Reasoning.Summary != "" {
reasoning.Summary = &request.Reasoning.Summary
}
}
return ResponsesResponse{
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
Status: "completed",
Model: model,
Output: output,
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
CompletedAt: nil, // Set by middleware when writing final response
Status: "completed",
IncompleteDetails: nil, // Only populated if response incomplete
Model: model,
PreviousResponseID: nil, // Not supported
Instructions: instructions,
Output: output,
Error: nil, // Only populated on failure
Tools: tools,
ToolChoice: "auto", // Default value
Truncation: truncation,
ParallelToolCalls: true, // Default value
Text: text,
TopP: derefFloat64(request.TopP, 1.0),
PresencePenalty: 0, // Default value
FrequencyPenalty: 0, // Default value
TopLogprobs: 0, // Default value
Temperature: derefFloat64(request.Temperature, 1.0),
Reasoning: reasoning,
Usage: &ResponsesUsage{
InputTokens: chatResponse.PromptEvalCount,
OutputTokens: chatResponse.EvalCount,
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
// TODO(drifkin): wire through the actual values
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
// TODO(drifkin): wire through the actual values
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
},
MaxOutputTokens: request.MaxOutputTokens,
MaxToolCalls: nil, // Not supported
Store: false, // We don't store responses
Background: request.Background,
ServiceTier: "default", // Default value
Metadata: map[string]any{},
SafetyIdentifier: nil, // Not supported
PromptCacheKey: nil, // Not supported
}
}
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
responseID string
itemID string
model string
request ResponsesRequest
// State tracking (mutated across Process calls)
firstWrite bool
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
}
// NewResponsesStreamConverter creates a new converter with the given configuration.
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
return &ResponsesStreamConverter{
responseID: responseID,
itemID: itemID,
model: model,
request: request,
firstWrite: true,
}
}
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
return events
}
// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
var instructions any = nil
if c.request.Instructions != "" {
instructions = c.request.Instructions
}
truncation := "disabled"
if c.request.Truncation != nil {
truncation = *c.request.Truncation
}
var tools []any
if c.request.Tools != nil {
for _, t := range c.request.Tools {
tools = append(tools, map[string]any{
"type": t.Type,
"name": t.Name,
"description": t.Description,
"strict": t.Strict,
"parameters": t.Parameters,
})
}
}
if tools == nil {
tools = []any{}
}
textFormat := map[string]any{"type": "text"}
if c.request.Text != nil && c.request.Text.Format != nil {
textFormat = map[string]any{
"type": c.request.Text.Format.Type,
}
if c.request.Text.Format.Name != "" {
textFormat["name"] = c.request.Text.Format.Name
}
if c.request.Text.Format.Schema != nil {
textFormat["schema"] = c.request.Text.Format.Schema
}
if c.request.Text.Format.Strict != nil {
textFormat["strict"] = *c.request.Text.Format.Strict
}
}
var reasoning any = nil
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
r := map[string]any{}
if c.request.Reasoning.Effort != "" {
r["effort"] = c.request.Reasoning.Effort
} else {
r["effort"] = nil
}
if c.request.Reasoning.Summary != "" {
r["summary"] = c.request.Reasoning.Summary
} else {
r["summary"] = nil
}
reasoning = r
}
// Build top_p and temperature with defaults
topP := 1.0
if c.request.TopP != nil {
topP = *c.request.TopP
}
temperature := 1.0
if c.request.Temperature != nil {
temperature = *c.request.Temperature
}
return map[string]any{
"id": c.responseID,
"object": "response",
"created_at": time.Now().Unix(),
"completed_at": nil,
"status": status,
"incomplete_details": nil,
"model": c.model,
"previous_response_id": nil,
"instructions": instructions,
"output": output,
"error": nil,
"tools": tools,
"tool_choice": "auto",
"truncation": truncation,
"parallel_tool_calls": true,
"text": map[string]any{"format": textFormat},
"top_p": topP,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_logprobs": 0,
"temperature": temperature,
"reasoning": reasoning,
"usage": usage,
"max_output_tokens": c.request.MaxOutputTokens,
"max_tool_calls": nil,
"store": false,
"background": c.request.Background,
"service_tier": "default",
"metadata": map[string]any{},
"safety_identifier": nil,
"prompt_cache_key": nil,
}
}
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
return c.newEvent("response.created", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
"response": c.buildResponseObject("in_progress", []any{}, nil),
})
}
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
return c.newEvent("response.in_progress", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
"response": c.buildResponseObject("in_progress", []any{}, nil),
})
}
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
// Emit delta
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"delta": thinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"delta": thinking,
}))
// TODO(drifkin): consider adding
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
events := []ResponsesStreamEvent{
c.newEvent("response.reasoning_summary_text.done", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"text": c.accumulatedThinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"text": c.accumulatedThinking,
}),
c.newEvent("response.output_item.done", map[string]any{
"output_index": c.outputIndex,
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": c.contentIndex,
"part": map[string]any{
"type": "output_text",
"text": "",
"type": "output_text",
"text": "",
"annotations": []any{},
"logprobs": []any{},
},
}))
}
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": 0,
"delta": content,
"logprobs": []any{},
}))
return events
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}},
})
}
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"text": c.accumulatedText,
"logprobs": []any{},
}))
// response.content_part.done
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"part": map[string]any{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
},
}))
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}},
},
}))
}
// response.completed
events = append(events, c.newEvent("response.completed", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "completed",
"output": c.buildFinalOutput(),
"usage": map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
},
usage := map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
"output_tokens_details": map[string]any{
"reasoning_tokens": 0,
},
}
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
response["completed_at"] = time.Now().Unix()
events = append(events, c.newEvent("response.completed", map[string]any{
"response": response,
}))
return events

View File

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with content
events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with thinking
events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42",
},
Done: true,
})
}, ResponsesRequest{})
// Should have 2 output items: reasoning + message
if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk
converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// Process some content
converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"strings"
)
type Prompt struct {
@@ -36,10 +37,11 @@ type Terminal struct {
}
type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
pastedLines []string
}
func New(prompt Prompt) (*Instance, error) {
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
buf.Remove()
if buf.IsEmpty() && len(i.pastedLines) > 0 {
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ:
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
i.Prompt.UseAlt = false
return output, nil
default:

View File

@@ -52,7 +52,6 @@ import (
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
xserver "github.com/ollama/ollama/x/server"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -1134,22 +1133,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
if req.System != "" {
m.System = req.System
}
@@ -1236,20 +1219,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err

View File

@@ -1,50 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors.
### Building ollama-mlx
The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
#### macOS (Apple Silicon and Intel)
```bash
# Build MLX backend libraries
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
go build -tags mlx -o ollama-mlx .
```
#### Linux (CUDA)
On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
```bash
# Build MLX backend libraries with CUDA support
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
go build -tags mlx -o ollama-mlx .
```
#### Using build scripts
The build scripts automatically create the `ollama-mlx` binary:
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
## Image Generation
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.

View File

@@ -25,14 +25,6 @@ import (
"github.com/ollama/ollama/x/tools"
)
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -656,7 +648,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
@@ -707,7 +699,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone
for {
line, err := scanner.Readline()
@@ -721,37 +712,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -860,41 +826,18 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
fmt.Println("Usage: /set system <message>")
continue
}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1081,7 +1024,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
if sb.Len() > 0 && multiline == MultilineNone {
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)

View File

@@ -1,282 +0,0 @@
// Package client provides client-side model creation for safetensors-based models.
//
// This package is in x/ because the safetensors model storage format is under development.
// It also exists to break an import cycle: server imports x/create, so x/create
// cannot import server. This sub-package can import server because server doesn't
// import it.
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
Template string
System string
License string
}
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
// CreateModel imports a model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Detect model type
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
isImageGen := create.IsTensorModelDir(opts.ModelDir)
if !isSafetensors && !isImageGen {
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
}
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
capabilities = []string{"image"}
}
// Set up progress spinner
statusMsg := "importing " + modelType
spinner := progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
progressFn := func(msg string) {
spinner.Stop()
statusMsg = msg
spinner = progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
}
// Create the model using shared callbacks
var err error
if isSafetensors {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
}
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
return nil
}
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
return create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
}
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
// When doQuantize is true, returns multiple layers (weight + scales + optional qbias).
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
return func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]create.LayerInfo, error) {
if doQuantize {
return createQuantizedLayers(r, name, dtype, shape)
}
return createUnquantizedLayer(r, name)
}
}
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []create.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name,
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale",
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, create.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias",
})
}
return layers, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []create.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
// Add Modelfile layers if present
if opts.Modelfile != nil {
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
layers = append(layers, layer)
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
layers = append(layers, layer)
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}
layers = append(layers, layer)
}
return layers, nil
}

View File

@@ -1,146 +0,0 @@
package client
import (
"testing"
)
func TestModelfileConfig(t *testing.T) {
// Test that ModelfileConfig struct works as expected
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
}
if config.Template != "{{ .Prompt }}" {
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
}
if config.System != "You are a helpful assistant." {
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
}
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
config := &ModelfileConfig{}
if config.Template != "" {
t.Errorf("Template should be empty, got %q", config.Template)
}
if config.System != "" {
t.Errorf("System should be empty, got %q", config.System)
}
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
// Test config with only some fields set
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
// System and License intentionally empty
}
if config.Template == "" {
t.Error("Template should not be empty")
}
if config.System != "" {
t.Error("System should be empty")
}
if config.License != "" {
t.Error("License should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
// Verify the minimum version constant is set
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion != "0.14.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
}
}
func TestCreateModel_InvalidDir(t *testing.T) {
// Test that CreateModel returns error for invalid directory
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: "/nonexistent/path",
}, nil)
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
// Test that CreateModel returns error for directory without safetensors
dir := t.TempDir()
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: dir,
}, nil)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateOptions(t *testing.T) {
opts := CreateOptions{
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
Modelfile: &ModelfileConfig{
Template: "test",
System: "system",
License: "MIT",
},
}
if opts.ModelName != "my-model" {
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
}
if opts.ModelDir != "/path/to/model" {
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
}
if opts.Quantize != "fp8" {
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
}
if opts.Modelfile == nil {
t.Error("Modelfile should not be nil")
}
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
}
func TestCreateOptions_Defaults(t *testing.T) {
opts := CreateOptions{
ModelName: "test",
ModelDir: "/tmp",
}
// Quantize should default to empty
if opts.Quantize != "" {
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
}
// Modelfile should default to nil
if opts.Modelfile != nil {
t.Error("Modelfile should be nil by default")
}
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)
supported := QuantizeSupported()
// In non-mlx builds, this should be false
// We can't easily test both cases, so just verify it returns something
_ = supported
}

View File

@@ -1,391 +0,0 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
ModelFormat string `json:"model_format"`
Capabilities []string `json:"capabilities"`
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"`
}
// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
}
// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}
// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, err
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, err
}
return &manifest, nil
}
// loadModelConfig loads the config blob for a model.
func loadModelConfig(modelName string) (*ModelConfig, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return nil, err
}
// Read the config blob
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return nil, err
}
var config ModelConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, err
}
return &config, nil
}
// IsSafetensorsModel checks if a model was created with the experimental
// safetensors builder by checking the model format in the config.
func IsSafetensorsModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors"
}
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
// (has completion capability, not image generation).
func IsSafetensorsLLMModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
}
// IsImageGenModel checks if a model is an image generation model
// (has image capability).
func IsImageGenModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
}
// GetModelArchitecture returns the architecture from the model's config.json layer.
func GetModelArchitecture(modelName string) (string, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return "", err
}
// Find the config.json layer
for _, layer := range manifest.Layers {
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
blobName := strings.Replace(layer.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return "", err
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", err
}
// Prefer model_type, fall back to first architecture
if cfg.ModelType != "" {
return cfg.ModelType, nil
}
if len(cfg.Architectures) > 0 {
return cfg.Architectures[0], nil
}
}
}
return "", fmt.Errorf("architecture not found in model config")
}
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
// by looking for config.json and at least one .safetensors file.
func IsSafetensorsModelDir(dir string) bool {
// Must have config.json
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
return false
}
// Must have at least one .safetensors file
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if strings.HasSuffix(entry.Name(), ".safetensors") {
return true
}
}
return false
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// ShouldQuantize returns true if a tensor should be quantized.
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
func ShouldQuantize(name, component string) bool {
// Image gen specific: skip VAE entirely
if component == "vae" {
return false
}
// Skip embeddings
if strings.Contains(name, "embed") {
return false
}
// Skip layer norms and RMS norms
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
return false
}
// Skip biases
if strings.HasSuffix(name, ".bias") {
return false
}
// Only quantize weights
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return false
}
return true
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
}
// Process all safetensors files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(modelDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Determine if this tensor should be quantized
doQuantize := quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape)
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, doQuantize)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
// Process all JSON config files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
// Skip the index file as we don't need it after extraction
if entry.Name() == "model.safetensors.index.json" {
continue
}
cfgPath := entry.Name()
fullPath := filepath.Join(modelDir, cfgPath)
fn(fmt.Sprintf("importing config %s", cfgPath))
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
f.Close()
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use config.json as the config layer
if cfgPath == "config.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("config.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}

View File

@@ -1,752 +0,0 @@
package create
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestIsTensorModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid diffusers model with model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0644)
},
expected: true,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "directory with other files but no model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsTensorModelDir(dir)
if got != tt.expected {
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid safetensors model with config.json and .safetensors file",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0644)
},
expected: true,
},
{
name: "config.json only, no safetensors files",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)
},
expected: false,
},
{
name: "safetensors file only, no config.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0644)
},
expected: false,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "multiple safetensors files with config.json",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644); err != nil {
return err
}
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0644)
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsSafetensorsModelDir(dir)
if got != tt.expected {
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
if got != false {
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
}
}
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
func createMinimalSafetensors(t *testing.T, path string) {
t.Helper()
// Create a minimal safetensors file with a single float32 tensor
header := map[string]interface{}{
"test_tensor": map[string]interface{}{
"dtype": "F32",
"shape": []int{2, 2},
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
},
}
headerJSON, err := json.Marshal(header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Write file
f, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file: %v", err)
}
defer f.Close()
// Write header size (8 bytes, little endian)
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
// Write header
if _, err := f.Write(headerJSON); err != nil {
t.Fatalf("failed to write header: %v", err)
}
// Write tensor data (16 bytes of zeros for 4 float32 values)
if _, err := f.Write(make([]byte, 16)); err != nil {
t.Fatalf("failed to write tensor data: %v", err)
}
}
func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Track what was created
var createdLayers []LayerInfo
var manifestWritten bool
var manifestModelName string
var manifestConfigLayer LayerInfo
var manifestLayers []LayerInfo
var statusMessages []string
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
layer := LayerInfo{
Digest: "sha256:test",
Size: int64(len(data)),
MediaType: mediaType,
Name: name,
}
createdLayers = append(createdLayers, layer)
return layer, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
layer := LayerInfo{
Digest: "sha256:tensor_" + name,
Size: int64(len(data)),
MediaType: "application/vnd.ollama.image.tensor",
Name: name,
}
createdLayers = append(createdLayers, layer)
return []LayerInfo{layer}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
manifestConfigLayer = config
manifestLayers = layers
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
// Run CreateSafetensorsModel
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify manifest was written
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-model" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
}
// Verify config layer was set
if manifestConfigLayer.Name != "config.json" {
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
}
// Verify we have at least one tensor and one config layer
hasTensor := false
hasConfig := false
for _, layer := range manifestLayers {
if layer.Name == "test_tensor" {
hasTensor = true
}
if layer.Name == "config.json" {
hasConfig = true
}
}
if !hasTensor {
t.Error("no tensor layer found in manifest")
}
if !hasConfig {
t.Error("no config layer found in manifest")
}
// Verify status messages were sent
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
dir := t.TempDir()
// Create only a safetensors file, no config.json
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Mock callbacks (minimal)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing config.json, got nil")
}
}
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
dir := t.TempDir()
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
return LayerInfo{}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
return []LayerInfo{{}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
dir := t.TempDir()
// Create config.json
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create model.safetensors.index.json (should be skipped)
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0644); err != nil {
t.Fatalf("failed to write index.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var configNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
configNames = append(configNames, name)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify model.safetensors.index.json was not included
for _, name := range configNames {
if name == "model.safetensors.index.json" {
t.Error("model.safetensors.index.json should have been skipped")
}
}
}
func TestResolveManifestPath(t *testing.T) {
tests := []struct {
name string
modelName string
wantParts []string // Parts that should appear in the path
}{
{
name: "simple model name",
modelName: "llama2",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
},
{
name: "model name with tag",
modelName: "llama2:7b",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
},
{
name: "model name with namespace",
modelName: "myuser/mymodel",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
},
{
name: "model name with namespace and tag",
modelName: "myuser/mymodel:v1",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
},
{
name: "fully qualified model name",
modelName: "registry.example.com/namespace/model:tag",
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := resolveManifestPath(tt.modelName)
for _, part := range tt.wantParts {
if !strings.Contains(got, part) {
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
}
}
})
}
}
func TestLayerInfo(t *testing.T) {
layer := LayerInfo{
Digest: "sha256:abc123",
Size: 1024,
MediaType: "application/vnd.ollama.image.tensor",
Name: "model.weight",
}
if layer.Digest != "sha256:abc123" {
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
}
if layer.Size != 1024 {
t.Errorf("Size = %d, want %d", layer.Size, 1024)
}
if layer.MediaType != "application/vnd.ollama.image.tensor" {
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
}
if layer.Name != "model.weight" {
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
}
}
func TestModelConfig(t *testing.T) {
config := ModelConfig{
ModelFormat: "safetensors",
Capabilities: []string{"completion", "chat"},
}
if config.ModelFormat != "safetensors" {
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
}
if len(config.Capabilities) != 2 {
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
}
}
func TestManifest(t *testing.T) {
manifest := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.oci.image.manifest.v1+json",
Config: ManifestLayer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: "sha256:config",
Size: 100,
},
Layers: []ManifestLayer{
{
MediaType: "application/vnd.ollama.image.tensor",
Digest: "sha256:layer1",
Size: 1000,
Name: "weight.bin",
},
},
}
if manifest.SchemaVersion != 2 {
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
}
if manifest.Config.Digest != "sha256:config" {
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
}
if len(manifest.Layers) != 1 {
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
}
if manifest.Layers[0].Name != "weight.bin" {
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
}
}
func TestShouldQuantize(t *testing.T) {
tests := []struct {
name string
tensor string
component string
want bool
}{
// VAE component should never be quantized
{"vae weight", "decoder.weight", "vae", false},
{"vae bias", "decoder.bias", "vae", false},
// Embeddings should not be quantized
{"embedding weight", "embed_tokens.weight", "", false},
{"embedding in name", "token_embedding.weight", "", false},
// Norms should not be quantized
{"layer norm", "layer_norm.weight", "", false},
{"rms norm", "rms_norm.weight", "", false},
{"ln prefix", "ln_1.weight", "", false},
{"layernorm in name", "input_layernorm.weight", "", false},
// Biases should not be quantized
{"bias tensor", "attention.bias", "", false},
{"proj bias", "o_proj.bias", "", false},
// Linear weights should be quantized
{"linear weight", "q_proj.weight", "", true},
{"attention weight", "self_attn.weight", "", true},
{"mlp weight", "mlp.gate_proj.weight", "", true},
// Transformer component weights should be quantized
{"transformer weight", "layers.0.weight", "transformer", true},
{"text_encoder weight", "encoder.weight", "text_encoder", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantize(tt.tensor, tt.component)
if got != tt.want {
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
}
})
}
}
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
}
})
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var quantizeRequested []bool
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
// Run with quantize enabled
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify quantize was passed to callback (will be false for small test tensor)
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}
// createMinimalImageGenModel creates a minimal diffusers-style model directory
func createMinimalImageGenModel(t *testing.T, dir string) {
t.Helper()
// Create model_index.json
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0644); err != nil {
t.Fatalf("failed to write model_index.json: %v", err)
}
// Create transformer directory with a safetensors file
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
// Create transformer config
transformerConfig := `{"hidden_size": 3072}`
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0644); err != nil {
t.Fatalf("failed to write transformer config: %v", err)
}
}
func TestCreateImageGenModel(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var manifestWritten bool
var manifestModelName string
var statusMessages []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-imagegen" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
}
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
dir := t.TempDir()
// Create only transformer without model_index.json
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing model_index.json, got nil")
}
}
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var quantizeRequested []bool
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}

190
x/imagegen/client/create.go Normal file
View File

@@ -0,0 +1,190 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}

View File

@@ -1,4 +1,4 @@
package create
package imagegen
import (
"bytes"
@@ -12,11 +12,37 @@ import (
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// CreateImageGenModel imports an image generation model from a directory.
// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes

22
x/imagegen/quantize.go Normal file
View File

@@ -0,0 +1,22 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}

View File

@@ -1,271 +0,0 @@
package server
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/imagegen"
)
// modelConfig represents the HuggingFace config.json structure
type modelConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
HiddenSize int `json:"hidden_size"`
NumHiddenLayers int `json:"num_hidden_layers"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
VocabSize int `json:"vocab_size"`
RMSNormEps float64 `json:"rms_norm_eps"`
RopeTheta float64 `json:"rope_theta"`
TorchDtype string `json:"torch_dtype"`
TextConfig *struct {
HiddenSize int `json:"hidden_size"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
NumHiddenLayers int `json:"num_hidden_layers"`
} `json:"text_config"`
}
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
// It reads the config.json layer and returns a map compatible with GGML's KV format.
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
var config modelConfig
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
// Calculate total tensor bytes from manifest layers
var totalBytes int64
var tensorCount int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalBytes += layer.Size
tensorCount++
}
}
return buildModelInfo(config, totalBytes, tensorCount), nil
}
// buildModelInfo constructs the model info map from config and tensor stats.
// This is separated for testability.
func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map[string]any {
// Determine architecture
arch := config.ModelType
if arch == "" && len(config.Architectures) > 0 {
// Convert HuggingFace architecture name to Ollama format
// e.g., "Gemma3ForCausalLM" -> "gemma3"
hfArch := config.Architectures[0]
arch = strings.ToLower(hfArch)
arch = strings.TrimSuffix(arch, "forcausallm")
arch = strings.TrimSuffix(arch, "forconditionalgeneration")
}
// Use text_config values if they exist (for multimodal models)
hiddenSize := config.HiddenSize
maxPosEmbed := config.MaxPositionEmbeddings
numLayers := config.NumHiddenLayers
if config.TextConfig != nil {
if config.TextConfig.HiddenSize > 0 {
hiddenSize = config.TextConfig.HiddenSize
}
if config.TextConfig.MaxPositionEmbeddings > 0 {
maxPosEmbed = config.TextConfig.MaxPositionEmbeddings
}
if config.TextConfig.NumHiddenLayers > 0 {
numLayers = config.TextConfig.NumHiddenLayers
}
}
// Get dtype to determine bytes per parameter for count calculation
dtype := config.TorchDtype
// Determine bytes per parameter based on dtype
var bytesPerParam int64 = 2 // default to float16/bfloat16
switch strings.ToLower(dtype) {
case "float32":
bytesPerParam = 4
case "float16", "bfloat16":
bytesPerParam = 2
case "int8", "uint8":
bytesPerParam = 1
}
// Subtract safetensors header overhead (88 bytes per tensor file)
// Each tensor is stored as a minimal safetensors file
totalBytes := totalTensorBytes - tensorCount*88
paramCount := totalBytes / bytesPerParam
info := map[string]any{
"general.architecture": arch,
}
if maxPosEmbed > 0 {
info[fmt.Sprintf("%s.context_length", arch)] = maxPosEmbed
}
if hiddenSize > 0 {
info[fmt.Sprintf("%s.embedding_length", arch)] = hiddenSize
}
if numLayers > 0 {
info[fmt.Sprintf("%s.block_count", arch)] = numLayers
}
if config.NumAttentionHeads > 0 {
info[fmt.Sprintf("%s.attention.head_count", arch)] = config.NumAttentionHeads
}
if config.NumKeyValueHeads > 0 {
info[fmt.Sprintf("%s.attention.head_count_kv", arch)] = config.NumKeyValueHeads
}
if config.IntermediateSize > 0 {
info[fmt.Sprintf("%s.feed_forward_length", arch)] = config.IntermediateSize
}
if config.VocabSize > 0 {
info[fmt.Sprintf("%s.vocab_size", arch)] = config.VocabSize
}
if paramCount > 0 {
info["general.parameter_count"] = paramCount
}
return info
}
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
return getTensorInfoFromManifest(manifest)
}
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
var tensors []api.Tensor
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
continue
}
// Read the safetensors header from the blob
blobPath := manifest.BlobPath(layer.Digest)
info, err := readSafetensorsHeader(blobPath)
if err != nil {
// Skip tensors we can't read
continue
}
// Convert shape from int to uint64
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: info.Dtype,
Shape: shape,
})
}
return tensors, nil
}
// GetSafetensorsDtype returns the torch_dtype from config.json for a safetensors model.
func GetSafetensorsDtype(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
return "", fmt.Errorf("failed to read config.json: %w", err)
}
return cfg.TorchDtype, nil
}
// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int64 `json:"shape"`
}
// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
// Safetensors format: 8-byte header size (little endian) + JSON header + tensor data
func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return parseSafetensorsHeader(f)
}
// parseSafetensorsHeader parses a safetensors header from a reader.
// This is separated for testability.
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
// Read header size (8 bytes, little endian)
var headerSize uint64
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}
// Sanity check - header shouldn't be too large
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header size too large: %d", headerSize)
}
// Read header JSON
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
// Parse as map of tensor name -> info
var header map[string]json.RawMessage
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
// Find the first (and should be only) tensor entry
for name, raw := range header {
if name == "__metadata__" {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
return &info, nil
}
return nil, fmt.Errorf("no tensor found in header")
}

View File

@@ -1,605 +0,0 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen"
)
func TestBuildModelInfo(t *testing.T) {
tests := []struct {
name string
config modelConfig
totalTensorBytes int64
tensorCount int64
wantArch string
wantContextLen int
wantEmbedLen int
wantBlockCount int
wantParamCount int64
}{
{
name: "gemma3 model with model_type",
config: modelConfig{
ModelType: "gemma3",
HiddenSize: 2560,
NumHiddenLayers: 34,
MaxPositionEmbeddings: 131072,
IntermediateSize: 10240,
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
wantEmbedLen: 2560,
wantBlockCount: 34,
wantParamCount: 4_300_000_000,
},
{
name: "llama model with architectures array",
config: modelConfig{
Architectures: []string{"LlamaForCausalLM"},
HiddenSize: 4096,
NumHiddenLayers: 32,
MaxPositionEmbeddings: 4096,
IntermediateSize: 11008,
NumAttentionHeads: 32,
NumKeyValueHeads: 32,
VocabSize: 32000,
TorchDtype: "float16",
},
totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
tensorCount: 1,
wantArch: "llama",
wantContextLen: 4096,
wantEmbedLen: 4096,
wantBlockCount: 32,
wantParamCount: 7_000_000_000,
},
{
name: "multimodal model with text_config",
config: modelConfig{
Architectures: []string{"Gemma3ForConditionalGeneration"},
HiddenSize: 1152, // vision hidden size
TextConfig: &struct {
HiddenSize int `json:"hidden_size"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
NumHiddenLayers int `json:"num_hidden_layers"`
}{
HiddenSize: 2560,
MaxPositionEmbeddings: 131072,
NumHiddenLayers: 34,
},
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088,
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
wantEmbedLen: 2560,
wantBlockCount: 34,
wantParamCount: 4_300_000_000,
},
{
name: "float32 model",
config: modelConfig{
ModelType: "test",
HiddenSize: 512,
NumHiddenLayers: 6,
MaxPositionEmbeddings: 2048,
TorchDtype: "float32",
},
totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
tensorCount: 1,
wantArch: "test",
wantContextLen: 2048,
wantEmbedLen: 512,
wantBlockCount: 6,
wantParamCount: 100_000_000,
},
{
name: "multiple tensors with header overhead",
config: modelConfig{
ModelType: "test",
HiddenSize: 256,
NumHiddenLayers: 4,
MaxPositionEmbeddings: 1024,
TorchDtype: "bfloat16",
},
totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
tensorCount: 10,
wantArch: "test",
wantContextLen: 1024,
wantEmbedLen: 256,
wantBlockCount: 4,
wantParamCount: 1_000_000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := buildModelInfo(tt.config, tt.totalTensorBytes, tt.tensorCount)
// Check architecture
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
}
// Check context length
contextKey := tt.wantArch + ".context_length"
if contextLen, ok := info[contextKey].(int); !ok || contextLen != tt.wantContextLen {
t.Errorf("context_length = %v, want %v", info[contextKey], tt.wantContextLen)
}
// Check embedding length
embedKey := tt.wantArch + ".embedding_length"
if embedLen, ok := info[embedKey].(int); !ok || embedLen != tt.wantEmbedLen {
t.Errorf("embedding_length = %v, want %v", info[embedKey], tt.wantEmbedLen)
}
// Check block count
blockKey := tt.wantArch + ".block_count"
if blockCount, ok := info[blockKey].(int); !ok || blockCount != tt.wantBlockCount {
t.Errorf("block_count = %v, want %v", info[blockKey], tt.wantBlockCount)
}
// Check parameter count
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
}
})
}
}
func TestBuildModelInfo_ArchitectureConversion(t *testing.T) {
tests := []struct {
name string
architectures []string
modelType string
wantArch string
}{
{
name: "LlamaForCausalLM",
architectures: []string{"LlamaForCausalLM"},
wantArch: "llama",
},
{
name: "Gemma3ForCausalLM",
architectures: []string{"Gemma3ForCausalLM"},
wantArch: "gemma3",
},
{
name: "Gemma3ForConditionalGeneration",
architectures: []string{"Gemma3ForConditionalGeneration"},
wantArch: "gemma3",
},
{
name: "Qwen2ForCausalLM",
architectures: []string{"Qwen2ForCausalLM"},
wantArch: "qwen2",
},
{
name: "model_type takes precedence",
architectures: []string{"LlamaForCausalLM"},
modelType: "custom",
wantArch: "custom",
},
{
name: "empty architectures with model_type",
architectures: nil,
modelType: "mymodel",
wantArch: "mymodel",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := modelConfig{
Architectures: tt.architectures,
ModelType: tt.modelType,
}
info := buildModelInfo(config, 0, 0)
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
}
})
}
}
func TestBuildModelInfo_BytesPerParam(t *testing.T) {
tests := []struct {
name string
dtype string
totalBytes int64
tensorCount int64
wantParamCount int64
}{
{
name: "bfloat16",
dtype: "bfloat16",
totalBytes: 2_000_088, // 1M * 2 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float16",
dtype: "float16",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float32",
dtype: "float32",
totalBytes: 4_000_088, // 1M * 4 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "int8",
dtype: "int8",
totalBytes: 1_000_088, // 1M * 1 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "unknown dtype defaults to 2 bytes",
dtype: "unknown",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "empty dtype defaults to 2 bytes",
dtype: "",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := modelConfig{
ModelType: "test",
TorchDtype: tt.dtype,
}
info := buildModelInfo(config, tt.totalBytes, tt.tensorCount)
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
}
})
}
}
func TestParseSafetensorsHeader(t *testing.T) {
tests := []struct {
name string
header map[string]any
wantDtype string
wantShape []int64
wantErr bool
}{
{
name: "simple tensor",
header: map[string]any{
"weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 262144},
"data_offsets": []int64{0, 1342177280},
},
},
wantDtype: "BF16",
wantShape: []int64{2560, 262144},
},
{
name: "with metadata",
header: map[string]any{
"__metadata__": map[string]any{
"format": "pt",
},
"bias": map[string]any{
"dtype": "F32",
"shape": []int64{1024},
"data_offsets": []int64{0, 4096},
},
},
wantDtype: "F32",
wantShape: []int64{1024},
},
{
name: "float16 tensor",
header: map[string]any{
"layer.weight": map[string]any{
"dtype": "F16",
"shape": []int64{512, 512, 3, 3},
"data_offsets": []int64{0, 4718592},
},
},
wantDtype: "F16",
wantShape: []int64{512, 512, 3, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create safetensors format: 8-byte size + JSON header
headerJSON, err := json.Marshal(tt.header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
var buf bytes.Buffer
if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
buf.Write(headerJSON)
info, err := parseSafetensorsHeader(&buf)
if (err != nil) != tt.wantErr {
t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if info.Dtype != tt.wantDtype {
t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype)
}
if len(info.Shape) != len(tt.wantShape) {
t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape))
} else {
for i, s := range info.Shape {
if s != tt.wantShape[i] {
t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i])
}
}
}
})
}
}
func TestParseSafetensorsHeader_Errors(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr string
}{
{
name: "empty data",
data: []byte{},
wantErr: "failed to read header size",
},
{
name: "truncated header size",
data: []byte{0x01, 0x02, 0x03},
wantErr: "failed to read header size",
},
{
name: "header size too large",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024)) // 2MB
return buf.Bytes()
}(),
wantErr: "header size too large",
},
{
name: "truncated header",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(100))
buf.Write([]byte("short"))
return buf.Bytes()
}(),
wantErr: "failed to read header",
},
{
name: "invalid JSON",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(10))
buf.Write([]byte("not json!!"))
return buf.Bytes()
}(),
wantErr: "failed to parse header",
},
{
name: "no tensors in header",
data: func() []byte {
header := map[string]any{
"__metadata__": map[string]any{"format": "pt"},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
return buf.Bytes()
}(),
wantErr: "no tensor found in header",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := parseSafetensorsHeader(bytes.NewReader(tt.data))
if err == nil {
t.Error("expected error, got nil")
return
}
if !bytes.Contains([]byte(err.Error()), []byte(tt.wantErr)) {
t.Errorf("error = %v, want error containing %v", err, tt.wantErr)
}
})
}
}
func TestGetTensorInfoFromManifest(t *testing.T) {
// Create a temp directory for blobs
tempDir, err := os.MkdirTemp("", "ollama-test-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// Create test tensor blobs
tensors := []struct {
name string
digest string
dtype string
shape []int64
}{
{
name: "model.embed_tokens.weight",
digest: "sha256:abc123",
dtype: "BF16",
shape: []int64{262144, 2560},
},
{
name: "model.layers.0.self_attn.q_proj.weight",
digest: "sha256:def456",
dtype: "BF16",
shape: []int64{2560, 2560},
},
{
name: "model.norm.weight",
digest: "sha256:ghi789",
dtype: "F32",
shape: []int64{2560},
},
}
// Create blob files
var layers []imagegen.ManifestLayer
for _, tensor := range tensors {
// Create safetensors blob
header := map[string]any{
tensor.name: map[string]any{
"dtype": tensor.dtype,
"shape": tensor.shape,
"data_offsets": []int64{0, 1000},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
// Write blob file
blobName := "sha256-" + tensor.digest[7:]
blobPath := filepath.Join(tempDir, blobName)
if err := os.WriteFile(blobPath, buf.Bytes(), 0644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.tensor",
Digest: tensor.digest,
Size: int64(buf.Len() + 1000), // header + fake data
Name: tensor.name,
})
}
// Add a non-tensor layer (should be skipped)
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.json",
Digest: "sha256:config",
Size: 100,
Name: "config.json",
})
manifest := &imagegen.ModelManifest{
Manifest: &imagegen.Manifest{
Layers: layers,
},
BlobDir: tempDir,
}
result, err := getTensorInfoFromManifest(manifest)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}
if len(result) != 3 {
t.Errorf("got %d tensors, want 3", len(result))
}
// Verify each tensor
for i, tensor := range tensors {
if i >= len(result) {
break
}
if result[i].Name != tensor.name {
t.Errorf("tensor[%d].Name = %v, want %v", i, result[i].Name, tensor.name)
}
if result[i].Type != tensor.dtype {
t.Errorf("tensor[%d].Type = %v, want %v", i, result[i].Type, tensor.dtype)
}
if len(result[i].Shape) != len(tensor.shape) {
t.Errorf("tensor[%d].Shape length = %v, want %v", i, len(result[i].Shape), len(tensor.shape))
}
}
}
func TestReadSafetensorsHeader(t *testing.T) {
// Create a temp file with a valid safetensors header
tempDir, err := os.MkdirTemp("", "ollama-test-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
header := map[string]any{
"test_tensor": map[string]any{
"dtype": "BF16",
"shape": []int64{1024, 768},
"data_offsets": []int64{0, 1572864},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
filePath := filepath.Join(tempDir, "test.safetensors")
if err := os.WriteFile(filePath, buf.Bytes(), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
info, err := readSafetensorsHeader(filePath)
if err != nil {
t.Fatalf("readSafetensorsHeader() error = %v", err)
}
if info.Dtype != "BF16" {
t.Errorf("Dtype = %v, want BF16", info.Dtype)
}
if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 {
t.Errorf("Shape = %v, want [1024, 768]", info.Shape)
}
}
func TestReadSafetensorsHeader_FileNotFound(t *testing.T) {
_, err := readSafetensorsHeader("/nonexistent/path/file.safetensors")
if err == nil {
t.Error("expected error for nonexistent file")
}
}