Compare commits
33 Commits
mlx-gpu-cd
...
imagegen-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bd0d841f1 | ||
|
|
a077d996e3 | ||
|
|
c23d5095de | ||
|
|
7601f0e93e | ||
|
|
aad3f03890 | ||
|
|
55d0b6e8b9 | ||
|
|
38eac40d56 | ||
|
|
80f3f1bc25 | ||
|
|
b1a0db547b | ||
|
|
75d7b5f926 | ||
|
|
349d814814 | ||
|
|
c8743031e0 | ||
|
|
4adb9cf4bb | ||
|
|
74f475e735 | ||
|
|
875cecba74 | ||
|
|
7d411a4686 | ||
|
|
02a2401596 | ||
|
|
e4b488a7b5 | ||
|
|
98079ddd79 | ||
|
|
d70942f47b | ||
|
|
58e4701557 | ||
|
|
dbf47ee55a | ||
|
|
af7ea6e96e | ||
|
|
8f1e0140e7 | ||
|
|
35c3c9e3c2 | ||
|
|
d06acbcb19 | ||
|
|
9667c2282f | ||
|
|
a937a68317 | ||
|
|
2185112d84 | ||
|
|
91926601dc | ||
|
|
361d6c16c2 | ||
|
|
7e2496e88e | ||
|
|
5b84e29882 |
2
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
@@ -13,7 +13,7 @@ body:
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant log output
|
||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
|
||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
|
||||
6
.github/workflows/release.yaml
vendored
@@ -372,13 +372,17 @@ jobs:
|
||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-to: type=inline
|
||||
- name: Deduplicate CUDA libraries
|
||||
run: |
|
||||
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
|
||||
@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
set(GGML_CPU_ALL_VARIANTS ON)
|
||||
endif()
|
||||
|
||||
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
||||
if(APPLE)
|
||||
set(CMAKE_BUILD_RPATH "@loader_path")
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
|
||||
endif()
|
||||
|
||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||
@@ -189,13 +190,21 @@ if(MLX_ENGINE)
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
)
|
||||
|
||||
# Install the Metal library for macOS arm64 (must be colocated with the binary)
|
||||
# Metal backend is only built for arm64, not x86_64
|
||||
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
|
||||
@@ -161,6 +161,9 @@ ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN mkdir -p dist/bin
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -182,6 +185,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
43
README.md
@@ -48,7 +48,7 @@ ollama run gemma3
|
||||
|
||||
## Model library
|
||||
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
|
||||
|
||||
Here are some example models that can be downloaded:
|
||||
|
||||
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
|
||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
|
||||
> [!NOTE]
|
||||
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
|
||||
./ollama run llama3.2
|
||||
```
|
||||
|
||||
## Building with MLX (experimental)
|
||||
|
||||
First build the MLX libraries:
|
||||
|
||||
```shell
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
|
||||
|
||||
```shell
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
Finally, start the server:
|
||||
|
||||
```
|
||||
./ollama serve
|
||||
```
|
||||
|
||||
### Building MLX with CUDA
|
||||
|
||||
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
|
||||
|
||||
```shell
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
cmake --build --preset 'MLX CUDA 13' --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
## REST API
|
||||
|
||||
Ollama has a REST API for running and managing models.
|
||||
@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Onyx](https://github.com/onyx-dot-app/onyx)
|
||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
|
||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||
@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
### Database
|
||||
|
||||
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
|
||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
||||
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
|
||||
@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
|
||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||
@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
||||
|
||||
### Security
|
||||
|
||||
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
|
||||
|
||||
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxBufferSize = 8 * format.MegaByte
|
||||
|
||||
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
||||
var buf io.Reader
|
||||
@@ -281,6 +281,26 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
|
||||
})
|
||||
}
|
||||
|
||||
// XGenerateResponseFunc is a function that [Client.XGenerate] invokes every time
|
||||
// a response is received from the experimental /x/generate endpoint.
|
||||
type XGenerateResponseFunc func(XGenerateResponse) error
|
||||
|
||||
// XGenerate generates images using the experimental /x/generate endpoint.
|
||||
// This endpoint is specifically designed for image generation models and
|
||||
// supports parameters like width, height, and steps.
|
||||
//
|
||||
// Experimental: This method may change or be removed in future versions.
|
||||
func (c *Client) XGenerate(ctx context.Context, req *XGenerateRequest, fn XGenerateResponseFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/x/generate", req, func(bts []byte) error {
|
||||
var resp XGenerateResponse
|
||||
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn(resp)
|
||||
})
|
||||
}
|
||||
|
||||
// ChatResponseFunc is a function that [Client.Chat] invokes every time
|
||||
// a response is received from the service. If this function returns an error,
|
||||
// [Client.Chat] will stop generating and return this error.
|
||||
|
||||
57
api/types.go
@@ -56,6 +56,63 @@ func (e AuthorizationError) Error() string {
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
// XGenerateRequest describes a request sent by [Client.XGenerate] to the
|
||||
// experimental /x/generate endpoint for image generation.
|
||||
// Experimental: This struct may change or be removed in future versions.
|
||||
type XGenerateRequest struct {
|
||||
// Model is the model name for an image generation model.
|
||||
Model string `json:"model"`
|
||||
|
||||
// Prompt is the text prompt describing the image to generate.
|
||||
Prompt string `json:"prompt"`
|
||||
|
||||
// Width is the width of the generated image in pixels.
|
||||
Width int32 `json:"width,omitempty"`
|
||||
|
||||
// Height is the height of the generated image in pixels.
|
||||
Height int32 `json:"height,omitempty"`
|
||||
|
||||
// Steps is the number of diffusion steps.
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
|
||||
// KeepAlive controls how long the model will stay loaded in memory.
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
}
|
||||
|
||||
// XGenerateResponse is the response from the experimental /x/generate endpoint.
|
||||
// Experimental: This struct may change or be removed in future versions.
|
||||
type XGenerateResponse struct {
|
||||
// Model is the model that generated the image.
|
||||
Model string `json:"model"`
|
||||
|
||||
// CreatedAt is the time the response was generated.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Status describes the current phase (e.g., "generating image").
|
||||
Status string `json:"status,omitempty"`
|
||||
|
||||
// Total is the total number of steps for image generation.
|
||||
Total int64 `json:"total,omitempty"`
|
||||
|
||||
// Completed is the number of completed steps.
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
|
||||
// Images contains base64-encoded generated images (in final response).
|
||||
Images []string `json:"images,omitempty"`
|
||||
|
||||
// Done indicates whether the generation is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
// DoneReason is the reason generation stopped.
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
|
||||
// TotalDuration is the total time spent generating the image.
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
|
||||
// LoadDuration is the time spent loading the model.
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
}
|
||||
|
||||
// GenerateRequest describes a request sent by [Client.Generate]. While you
|
||||
// have to specify the Model and Prompt fields, all the other fields have
|
||||
// reasonable defaults for basic uses.
|
||||
|
||||
@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
|
||||
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
|
||||
@property(strong, nonatomic) NSStatusItem *statusItem;
|
||||
@property(assign, nonatomic) BOOL updateAvailable;
|
||||
@property(assign, nonatomic) BOOL systemShutdownInProgress;
|
||||
@end
|
||||
|
||||
@implementation AppDelegate
|
||||
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
}
|
||||
|
||||
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
|
||||
// Register for system shutdown/restart notification so we can allow termination
|
||||
[[[NSWorkspace sharedWorkspace] notificationCenter]
|
||||
addObserver:self
|
||||
selector:@selector(systemWillPowerOff:)
|
||||
name:NSWorkspaceWillPowerOffNotification
|
||||
object:nil];
|
||||
|
||||
// if we're in development mode, set the app icon
|
||||
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
|
||||
if (![bundlePath hasSuffix:@".app"]) {
|
||||
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
[NSApp activateIgnoringOtherApps:YES];
|
||||
}
|
||||
|
||||
- (void)systemWillPowerOff:(NSNotification *)notification {
|
||||
// Set flag so applicationShouldTerminate: knows to allow termination.
|
||||
// The system will call applicationShouldTerminate: after posting this notification.
|
||||
self.systemShutdownInProgress = YES;
|
||||
}
|
||||
|
||||
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
|
||||
// Allow termination if the system is shutting down or restarting
|
||||
if (self.systemShutdownInProgress) {
|
||||
return NSTerminateNow;
|
||||
}
|
||||
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
|
||||
[NSApp hide:nil];
|
||||
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
|
||||
return NSTerminateCancel;
|
||||
|
||||
130
cmd/cmd.go
@@ -46,8 +46,9 @@ import (
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
xcreateclient "github.com/ollama/ollama/x/create/client"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -93,14 +94,87 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Validate model name early to fail fast
|
||||
modelName := args[0]
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) || filename == "" {
|
||||
// No Modelfile specified or found - use default
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
}
|
||||
|
||||
// Parse the Modelfile
|
||||
modelfile, err := parser.ParseFile(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||
}
|
||||
|
||||
// Extract FROM path and configuration
|
||||
var modelDir string
|
||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
if !filepath.IsAbs(modelDir) && filename != "" {
|
||||
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
|
||||
}
|
||||
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
Modelfile: mfConfig,
|
||||
}, p)
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
return imagegenclient.CreateModel(args[0], ".", p)
|
||||
if create.IsTensorModelDir(".") {
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: ".",
|
||||
Quantize: quantize,
|
||||
}, p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
@@ -133,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
spinner.Stop()
|
||||
|
||||
req.Model = args[0]
|
||||
req.Model = modelName
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
@@ -464,14 +538,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
name := args[0]
|
||||
|
||||
// Check if this is a known image generation model (skip Show/Pull)
|
||||
if imagegen.HasTensorLayers(name) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
@@ -533,9 +599,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// Check if this is an image generation model
|
||||
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||
}
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
||||
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
@@ -565,7 +640,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
// Use experimental agent loop with tools
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
@@ -671,7 +746,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
||||
msg := resp.Status
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
|
||||
}
|
||||
bar = progress.NewBar(msg, resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
@@ -837,11 +916,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
// Check if this is an image generation model
|
||||
if imagegen.HasTensorLayers(args[0]) {
|
||||
return imagegen.Show(args[0], os.Stdout)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1741,15 +1815,22 @@ func NewCLI() *cobra.Command {
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Skip server check for experimental mode (writes directly to disk)
|
||||
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
|
||||
return nil
|
||||
}
|
||||
return checkServerHeartbeat(cmd, args)
|
||||
},
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
@@ -1786,6 +1867,7 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
|
||||
|
||||
// Image generation flags (width, height, steps, seed, etc.)
|
||||
imagegen.RegisterFlags(runCmd)
|
||||
|
||||
@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowInfoImageGen(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "ZImagePipeline",
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImageGeneration},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := " Model\n" +
|
||||
" architecture ZImagePipeline \n" +
|
||||
" parameters 10.3B \n" +
|
||||
" quantization FP8 \n" +
|
||||
" requires 0.14.0 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
" image \n" +
|
||||
"\n"
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushProgressMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
digest string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "uses status when provided",
|
||||
status: "uploading model",
|
||||
digest: "sha256:abc123456789def",
|
||||
wantMsg: "uploading model",
|
||||
},
|
||||
{
|
||||
name: "falls back to digest when status empty",
|
||||
status: "",
|
||||
digest: "sha256:abc123456789def",
|
||||
wantMsg: "pushing abc123456789...",
|
||||
},
|
||||
{
|
||||
name: "handles short digest gracefully",
|
||||
status: "",
|
||||
digest: "sha256:abc",
|
||||
wantMsg: "pushing sha256:abc...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msg := tt.status
|
||||
if msg == "" {
|
||||
if len(tt.digest) >= 19 {
|
||||
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
|
||||
} else {
|
||||
msg = fmt.Sprintf("pushing %s...", tt.digest)
|
||||
}
|
||||
}
|
||||
if msg != tt.wantMsg {
|
||||
t.Errorf("got %q, want %q", msg, tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
// Test that modifications to original don't affect copy
|
||||
originalThink := &api.ThinkValue{Value: "original"}
|
||||
|
||||
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
83
docs/api.md
@@ -16,6 +16,7 @@
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
- [Version](#version)
|
||||
- [Experimental: Generate an image](#generate-an-image-experimental)
|
||||
|
||||
## Conventions
|
||||
|
||||
@@ -1867,3 +1868,85 @@ curl http://localhost:11434/api/version
|
||||
"version": "0.5.1"
|
||||
}
|
||||
```
|
||||
|
||||
## Experimental Endpoints
|
||||
|
||||
### Generate an image (Experimental)
|
||||
|
||||
```
|
||||
POST /x/generate
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> This endpoint is experimental and may change in future versions.
|
||||
|
||||
Generate an image using an image generation model. This endpoint is specifically designed for diffusion-based image generation models.
|
||||
|
||||
#### Parameters
|
||||
|
||||
- `model`: (required) the [model name](#model-names) of an image generation model
|
||||
- `prompt`: the text prompt describing the image to generate
|
||||
|
||||
Image generation parameters (optional):
|
||||
|
||||
- `width`: width of the generated image in pixels (default: model-specific, typically 1024)
|
||||
- `height`: height of the generated image in pixels (default: model-specific, typically 1024)
|
||||
- `steps`: number of diffusion steps (default: model-specific)
|
||||
|
||||
Other parameters:
|
||||
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
#### Response
|
||||
|
||||
The response is streamed as JSON objects showing generation progress:
|
||||
|
||||
- `status`: describes the current phase (e.g., "generating image")
|
||||
- `total`: total number of steps
|
||||
- `completed`: number of completed steps
|
||||
- `done`: whether generation is complete
|
||||
|
||||
The final response includes:
|
||||
|
||||
- `images`: array of base64-encoded generated images
|
||||
- `total_duration`: time spent generating the image
|
||||
- `load_duration`: time spent loading the model
|
||||
|
||||
#### Examples
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/x/generate -d '{
|
||||
"model": "flux",
|
||||
"prompt": "a sunset over mountains",
|
||||
"width": 1024,
|
||||
"height": 768
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response (streaming)
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "flux",
|
||||
"created_at": "2024-01-15T10:30:00.000000Z",
|
||||
"status": "generating image",
|
||||
"completed": 5,
|
||||
"total": 20,
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
##### Final Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "flux",
|
||||
"created_at": "2024-01-15T10:30:15.000000Z",
|
||||
"images": ["iVBORw0KGgoAAAANSUhEUg..."],
|
||||
"done": true,
|
||||
"total_duration": 15000000000,
|
||||
"load_duration": 2000000000
|
||||
}
|
||||
```
|
||||
|
||||
@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const results = await client.webSearch({ query: "what is ollama?" });
|
||||
const results = await client.webSearch("what is ollama?");
|
||||
console.log(JSON.stringify(results, null, 2));
|
||||
```
|
||||
|
||||
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
|
||||
const fetchResult = await client.webFetch("https://ollama.com");
|
||||
console.log(JSON.stringify(fetchResult, null, 2));
|
||||
```
|
||||
|
||||
|
||||
@@ -111,7 +111,9 @@
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode"
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
|
||||
BIN
docs/images/marimo-add-model.png
Normal file
|
After Width: | Height: | Size: 174 KiB |
BIN
docs/images/marimo-chat.png
Normal file
|
After Width: | Height: | Size: 80 KiB |
BIN
docs/images/marimo-code-completion.png
Normal file
|
After Width: | Height: | Size: 230 KiB |
BIN
docs/images/marimo-models.png
Normal file
|
After Width: | Height: | Size: 178 KiB |
BIN
docs/images/marimo-settings.png
Normal file
|
After Width: | Height: | Size: 186 KiB |
BIN
docs/images/onyx-login.png
Normal file
|
After Width: | Height: | Size: 100 KiB |
BIN
docs/images/onyx-ollama-form.png
Normal file
|
After Width: | Height: | Size: 306 KiB |
BIN
docs/images/onyx-ollama-llm.png
Normal file
|
After Width: | Height: | Size: 300 KiB |
BIN
docs/images/onyx-query.png
Normal file
|
After Width: | Height: | Size: 211 KiB |
@@ -25,6 +25,7 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
@@ -38,7 +39,7 @@ claude --model qwen3-coder
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
73
docs/integrations/marimo.mdx
Normal file
@@ -0,0 +1,73 @@
|
||||
---
|
||||
title: marimo
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
|
||||
can also use `uv` to create a sandboxed environment for marimo by running:
|
||||
|
||||
```
|
||||
uvx marimo edit --sandbox notebook.py
|
||||
```
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. In marimo, go to the user settings and go to the AI tab. From here
|
||||
you can find and configure Ollama as an AI provider. For local use you
|
||||
would typically point the base url to `http://localhost:11434/v1`.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-settings.png"
|
||||
alt="Ollama settings in marimo"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-models.png"
|
||||
alt="Selecting an Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-add-model.png"
|
||||
alt="Adding a new Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Once configured, you can now use Ollama for AI chats in marimo.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-chat.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-code-completion.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Sign in to ollama cloud via `ollama signin`
|
||||
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
|
||||
3. You can now refer to this model in marimo!
|
||||
63
docs/integrations/onyx.mdx
Normal file
@@ -0,0 +1,63 @@
|
||||
---
|
||||
title: Onyx
|
||||
---
|
||||
|
||||
## Overview
|
||||
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
|
||||
- Creating custom Agents
|
||||
- Web search
|
||||
- Deep Research
|
||||
- RAG over uploaded documents and connected apps
|
||||
- Connectors to applications like Google Drive, Email, Slack, etc.
|
||||
- MCP and OpenAPI Actions support
|
||||
- Image generation
|
||||
- User/Groups management, RBAC, SSO, etc.
|
||||
|
||||
Onyx can be deployed for single users or large organizations.
|
||||
|
||||
## Install Onyx
|
||||
|
||||
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
|
||||
|
||||
<Info>
|
||||
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
|
||||
</Info>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. Login to your Onyx deployment (create an account first).
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-login.png"
|
||||
alt="Onyx Login Page"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
2. In the set-up process select `Ollama` as the LLM provider.
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-llm.png"
|
||||
alt="Onyx Set Up Form"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
3. Provide your **Ollama API URL** and select your models.
|
||||
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-form.png"
|
||||
alt="Selecting Ollama Models"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
|
||||
|
||||
## Send your first query
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-query.png"
|
||||
alt="Onyx Query Example"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Linux"
|
||||
title: Linux
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -40,8 +41,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -112,7 +113,11 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# Troubleshooting
|
||||
|
||||
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
|
||||
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -1464,6 +1464,12 @@ type CompletionRequest struct {
|
||||
|
||||
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
|
||||
TopLogprobs int
|
||||
|
||||
// Image generation fields
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {
|
||||
|
||||
// Logprobs contains log probability information if requested
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Image contains base64-encoded image data for image generation
|
||||
Image string `json:"image,omitempty"`
|
||||
|
||||
// Step is the current step in image generation
|
||||
Step int `json:"step,omitempty"`
|
||||
|
||||
// TotalSteps is the total number of steps for image generation
|
||||
TotalSteps int `json:"total_steps,omitempty"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
|
||||
@@ -118,6 +118,9 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||
c.Set("relax_thinking", true)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||
|
||||
@@ -582,3 +582,26 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var flagSet bool
|
||||
router := gin.New()
|
||||
router.Use(AnthropicMessagesMiddleware())
|
||||
router.POST("/v1/messages", func(c *gin.Context) {
|
||||
_, flagSet = c.Get("relax_thinking")
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if !flagSet {
|
||||
t.Error("expected relax_thinking flag to be set in context")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -49,6 +50,11 @@ type EmbedWriter struct {
|
||||
encodingFormat string
|
||||
}
|
||||
|
||||
type ImageWriter struct {
|
||||
BaseWriter
|
||||
done bool
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
@@ -273,6 +279,36 @@ func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.XGenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Image generation doesn't support streaming in the OpenAI API sense,
|
||||
// so we only write the response when done with images
|
||||
if generateResponse.Done && len(generateResponse.Images) > 0 {
|
||||
w.done = true
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ImageWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
@@ -392,6 +428,43 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageGenerationRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq := openai.FromImageGenerationRequest(req)
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
@@ -441,6 +514,7 @@ type ResponsesWriter struct {
|
||||
stream bool
|
||||
responseID string
|
||||
itemID string
|
||||
request openai.ResponsesRequest
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||
@@ -478,7 +552,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// Non-streaming response
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
|
||||
completedAt := time.Now().Unix()
|
||||
response.CompletedAt = &completedAt
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
@@ -523,11 +599,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
|
||||
w := &ResponsesWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
|
||||
model: req.Model,
|
||||
stream: streamRequested,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
request: req,
|
||||
}
|
||||
|
||||
// Set headers based on streaming mode
|
||||
|
||||
@@ -961,3 +961,142 @@ func TestRetrieveMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerationsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.XGenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.XGenerateRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image generation handler",
|
||||
body: `{
|
||||
"model": "flux",
|
||||
"prompt": "a cat"
|
||||
}`,
|
||||
req: api.XGenerateRequest{
|
||||
Model: "flux",
|
||||
Prompt: "a cat",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation with size",
|
||||
body: `{
|
||||
"model": "flux",
|
||||
"prompt": "a dog",
|
||||
"size": "512x512"
|
||||
}`,
|
||||
req: api.XGenerateRequest{
|
||||
Model: "flux",
|
||||
Prompt: "a dog",
|
||||
Width: 512,
|
||||
Height: 512,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing prompt error",
|
||||
body: `{
|
||||
"model": "flux"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing model error",
|
||||
body: `{
|
||||
"prompt": "a cat"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatalf("requests did not match\nExpected: %+v\nActual: %+v", tc.req, *capturedRequest)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatalf("errors did not match\nExpected: %+v\nActual: %+v", tc.err, errResp)
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageWriterIntegration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("transforms generate response to openai format", func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware())
|
||||
router.POST("/api/generate", func(c *gin.Context) {
|
||||
// Simulate an image generation response
|
||||
generateResponse := api.XGenerateResponse{
|
||||
Done: true,
|
||||
CreatedAt: time.Now(),
|
||||
Images: []string{"base64encodedimage"},
|
||||
}
|
||||
c.JSON(http.StatusOK, generateResponse)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(`{"model":"flux","prompt":"a cat"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var response openai.ImageGenerationResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(response.Data) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(response.Data))
|
||||
}
|
||||
if response.Data[0].B64JSON != "base64encodedimage" {
|
||||
t.Fatalf("expected image data 'base64encodedimage', got '%s'", response.Data[0].B64JSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
|
||||
|
||||
// decodeImageURL decodes a base64 data URI into raw image bytes.
|
||||
func decodeImageURL(url string) (api.ImageData, error) {
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
|
||||
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||
|
||||
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
|
||||
@@ -733,3 +737,53 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageURLOrData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageURLOrData contains either a URL or base64-encoded image data.
|
||||
type ImageURLOrData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama XGenerateRequest.
|
||||
func FromImageGenerationRequest(r ImageGenerationRequest) api.XGenerateRequest {
|
||||
req := api.XGenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// ToImageGenerationResponse converts an Ollama XGenerateResponse to an OpenAI ImageGenerationResponse.
|
||||
func ToImageGenerationResponse(resp api.XGenerateResponse) ImageGenerationResponse {
|
||||
data := make([]ImageURLOrData, 0)
|
||||
for _, img := range resp.Images {
|
||||
data = append(data, ImageURLOrData{B64JSON: img})
|
||||
}
|
||||
return ImageGenerationResponse{
|
||||
Created: resp.CreatedAt.Unix(),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -265,9 +266,9 @@ type ResponsesText struct {
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Strict bool `json:"strict,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
Description *string `json:"description"` // nullable but required
|
||||
Strict *bool `json:"strict"` // nullable but required
|
||||
Parameters map[string]any `json:"parameters"` // nullable but required
|
||||
}
|
||||
|
||||
type ResponsesRequest struct {
|
||||
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var description string
|
||||
if t.Description != nil {
|
||||
description = *t.Description
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: t.Type,
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
|
||||
|
||||
// Response types for the Responses API
|
||||
|
||||
// ResponsesTextField represents the text output configuration in the response.
|
||||
type ResponsesTextField struct {
|
||||
Format ResponsesTextFormat `json:"format"`
|
||||
}
|
||||
|
||||
// ResponsesReasoningOutput represents reasoning configuration in the response.
|
||||
type ResponsesReasoningOutput struct {
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Summary *string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError represents an error in the response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails represents details about why a response was incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
|
||||
// requires additional plumbing to find the effective values since the
|
||||
// defaults can come from the model or the request
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt *int64 `json:"completed_at"`
|
||||
Status string `json:"status"`
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
|
||||
Model string `json:"model"`
|
||||
PreviousResponseID *string `json:"previous_response_id"`
|
||||
Instructions *string `json:"instructions"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Error *ResponsesError `json:"error"`
|
||||
Tools []ResponsesTool `json:"tools"`
|
||||
ToolChoice any `json:"tool_choice"`
|
||||
Truncation string `json:"truncation"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
Text ResponsesTextField `json:"text"`
|
||||
TopP float64 `json:"top_p"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
TopLogprobs int `json:"top_logprobs"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
|
||||
Usage *ResponsesUsage `json:"usage"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
MaxToolCalls *int `json:"max_tool_calls"`
|
||||
Store bool `json:"store"`
|
||||
Background bool `json:"background"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
SafetyIdentifier *string `json:"safety_identifier"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key"`
|
||||
}
|
||||
|
||||
type ResponsesOutputItem struct {
|
||||
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Annotations []any `json:"annotations"`
|
||||
Logprobs []any `json:"logprobs"`
|
||||
}
|
||||
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
|
||||
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
|
||||
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
|
||||
func derefFloat64(p *float64, def float64) float64 {
|
||||
if p != nil {
|
||||
return *p
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response.
|
||||
// The request is used to echo back request parameters in the response.
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
|
||||
var output []ResponsesOutputItem
|
||||
|
||||
// Add reasoning item if thinking is present
|
||||
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
|
||||
output = append(output, ResponsesOutputItem{
|
||||
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
|
||||
Role: "assistant",
|
||||
Content: []ResponsesOutputContent{
|
||||
{
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Annotations: []any{},
|
||||
Logprobs: []any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var instructions *string
|
||||
if request.Instructions != "" {
|
||||
instructions = &request.Instructions
|
||||
}
|
||||
|
||||
// Build truncation with default
|
||||
truncation := "disabled"
|
||||
if request.Truncation != nil {
|
||||
truncation = *request.Truncation
|
||||
}
|
||||
|
||||
tools := request.Tools
|
||||
if tools == nil {
|
||||
tools = []ResponsesTool{}
|
||||
}
|
||||
|
||||
text := ResponsesTextField{
|
||||
Format: ResponsesTextFormat{Type: "text"},
|
||||
}
|
||||
if request.Text != nil && request.Text.Format != nil {
|
||||
text.Format = *request.Text.Format
|
||||
}
|
||||
|
||||
// Build reasoning output from request
|
||||
var reasoning *ResponsesReasoningOutput
|
||||
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
|
||||
reasoning = &ResponsesReasoningOutput{}
|
||||
if request.Reasoning.Effort != "" {
|
||||
reasoning.Effort = &request.Reasoning.Effort
|
||||
}
|
||||
if request.Reasoning.Summary != "" {
|
||||
reasoning.Summary = &request.Reasoning.Summary
|
||||
}
|
||||
}
|
||||
|
||||
return ResponsesResponse{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
Status: "completed",
|
||||
Model: model,
|
||||
Output: output,
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
CompletedAt: nil, // Set by middleware when writing final response
|
||||
Status: "completed",
|
||||
IncompleteDetails: nil, // Only populated if response incomplete
|
||||
Model: model,
|
||||
PreviousResponseID: nil, // Not supported
|
||||
Instructions: instructions,
|
||||
Output: output,
|
||||
Error: nil, // Only populated on failure
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // Default value
|
||||
Truncation: truncation,
|
||||
ParallelToolCalls: true, // Default value
|
||||
Text: text,
|
||||
TopP: derefFloat64(request.TopP, 1.0),
|
||||
PresencePenalty: 0, // Default value
|
||||
FrequencyPenalty: 0, // Default value
|
||||
TopLogprobs: 0, // Default value
|
||||
Temperature: derefFloat64(request.Temperature, 1.0),
|
||||
Reasoning: reasoning,
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: chatResponse.PromptEvalCount,
|
||||
OutputTokens: chatResponse.EvalCount,
|
||||
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
|
||||
// TODO(drifkin): wire through the actual values
|
||||
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
|
||||
// TODO(drifkin): wire through the actual values
|
||||
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
|
||||
},
|
||||
MaxOutputTokens: request.MaxOutputTokens,
|
||||
MaxToolCalls: nil, // Not supported
|
||||
Store: false, // We don't store responses
|
||||
Background: request.Background,
|
||||
ServiceTier: "default", // Default value
|
||||
Metadata: map[string]any{},
|
||||
SafetyIdentifier: nil, // Not supported
|
||||
PromptCacheKey: nil, // Not supported
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
|
||||
responseID string
|
||||
itemID string
|
||||
model string
|
||||
request ResponsesRequest
|
||||
|
||||
// State tracking (mutated across Process calls)
|
||||
firstWrite bool
|
||||
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
|
||||
}
|
||||
|
||||
// NewResponsesStreamConverter creates a new converter with the given configuration.
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
|
||||
return &ResponsesStreamConverter{
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
model: model,
|
||||
request: request,
|
||||
firstWrite: true,
|
||||
}
|
||||
}
|
||||
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
|
||||
return events
|
||||
}
|
||||
|
||||
// buildResponseObject creates a full response object with all required fields for streaming events.
|
||||
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
|
||||
var instructions any = nil
|
||||
if c.request.Instructions != "" {
|
||||
instructions = c.request.Instructions
|
||||
}
|
||||
|
||||
truncation := "disabled"
|
||||
if c.request.Truncation != nil {
|
||||
truncation = *c.request.Truncation
|
||||
}
|
||||
|
||||
var tools []any
|
||||
if c.request.Tools != nil {
|
||||
for _, t := range c.request.Tools {
|
||||
tools = append(tools, map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": t.Description,
|
||||
"strict": t.Strict,
|
||||
"parameters": t.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if tools == nil {
|
||||
tools = []any{}
|
||||
}
|
||||
|
||||
textFormat := map[string]any{"type": "text"}
|
||||
if c.request.Text != nil && c.request.Text.Format != nil {
|
||||
textFormat = map[string]any{
|
||||
"type": c.request.Text.Format.Type,
|
||||
}
|
||||
if c.request.Text.Format.Name != "" {
|
||||
textFormat["name"] = c.request.Text.Format.Name
|
||||
}
|
||||
if c.request.Text.Format.Schema != nil {
|
||||
textFormat["schema"] = c.request.Text.Format.Schema
|
||||
}
|
||||
if c.request.Text.Format.Strict != nil {
|
||||
textFormat["strict"] = *c.request.Text.Format.Strict
|
||||
}
|
||||
}
|
||||
|
||||
var reasoning any = nil
|
||||
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
|
||||
r := map[string]any{}
|
||||
if c.request.Reasoning.Effort != "" {
|
||||
r["effort"] = c.request.Reasoning.Effort
|
||||
} else {
|
||||
r["effort"] = nil
|
||||
}
|
||||
if c.request.Reasoning.Summary != "" {
|
||||
r["summary"] = c.request.Reasoning.Summary
|
||||
} else {
|
||||
r["summary"] = nil
|
||||
}
|
||||
reasoning = r
|
||||
}
|
||||
|
||||
// Build top_p and temperature with defaults
|
||||
topP := 1.0
|
||||
if c.request.TopP != nil {
|
||||
topP = *c.request.TopP
|
||||
}
|
||||
temperature := 1.0
|
||||
if c.request.Temperature != nil {
|
||||
temperature = *c.request.Temperature
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"completed_at": nil,
|
||||
"status": status,
|
||||
"incomplete_details": nil,
|
||||
"model": c.model,
|
||||
"previous_response_id": nil,
|
||||
"instructions": instructions,
|
||||
"output": output,
|
||||
"error": nil,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"truncation": truncation,
|
||||
"parallel_tool_calls": true,
|
||||
"text": map[string]any{"format": textFormat},
|
||||
"top_p": topP,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
"top_logprobs": 0,
|
||||
"temperature": temperature,
|
||||
"reasoning": reasoning,
|
||||
"usage": usage,
|
||||
"max_output_tokens": c.request.MaxOutputTokens,
|
||||
"max_tool_calls": nil,
|
||||
"store": false,
|
||||
"background": c.request.Background,
|
||||
"service_tier": "default",
|
||||
"metadata": map[string]any{},
|
||||
"safety_identifier": nil,
|
||||
"prompt_cache_key": nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.created", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.in_progress", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
|
||||
|
||||
// Emit delta
|
||||
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"delta": thinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"delta": thinking,
|
||||
}))
|
||||
|
||||
// TODO(drifkin): consider adding
|
||||
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
|
||||
|
||||
events := []ResponsesStreamEvent{
|
||||
c.newEvent("response.reasoning_summary_text.done", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"text": c.accumulatedThinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"text": c.accumulatedThinking,
|
||||
}),
|
||||
c.newEvent("response.output_item.done", map[string]any{
|
||||
"output_index": c.outputIndex,
|
||||
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": c.contentIndex,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"delta": content,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
return events
|
||||
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
})
|
||||
}
|
||||
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"text": c.accumulatedText,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
// response.content_part.done
|
||||
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// response.completed
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": c.buildFinalOutput(),
|
||||
"usage": map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
usage := map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
"output_tokens_details": map[string]any{
|
||||
"reasoning_tokens": 0,
|
||||
},
|
||||
}
|
||||
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
|
||||
response["completed_at"] = time.Now().Unix()
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": response,
|
||||
}))
|
||||
|
||||
return events
|
||||
|
||||
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk with content
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk with thinking
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
|
||||
Content: "The answer is 42",
|
||||
},
|
||||
Done: true,
|
||||
})
|
||||
}, ResponsesRequest{})
|
||||
|
||||
// Should have 2 output items: reasoning + message
|
||||
if len(response.Output) != 2 {
|
||||
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
// Verify that response.output_item.done includes content field for messages
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.completed includes the output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// Process some content
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.created includes an empty output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hi"},
|
||||
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
// Verify that events include incrementing sequence numbers
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hello"},
|
||||
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||
// Verify that function call items include status field
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -36,10 +37,11 @@ type Terminal struct {
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
pastedLines []string
|
||||
}
|
||||
|
||||
func New(prompt Prompt) (*Instance, error) {
|
||||
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharEsc:
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
i.pastedLines = nil
|
||||
i.Prompt.UseAlt = false
|
||||
return "", ErrInterrupt
|
||||
case CharPrev:
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharForward:
|
||||
buf.MoveRight()
|
||||
case CharBackspace, CharCtrlH:
|
||||
buf.Remove()
|
||||
if buf.IsEmpty() && len(i.pastedLines) > 0 {
|
||||
lastIdx := len(i.pastedLines) - 1
|
||||
prevLine := i.pastedLines[lastIdx]
|
||||
i.pastedLines = i.pastedLines[:lastIdx]
|
||||
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
|
||||
if len(i.pastedLines) == 0 {
|
||||
fmt.Print(i.Prompt.Prompt)
|
||||
i.Prompt.UseAlt = false
|
||||
} else {
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
}
|
||||
for _, r := range prevLine {
|
||||
buf.Add(r)
|
||||
}
|
||||
} else {
|
||||
buf.Remove()
|
||||
}
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for range 8 {
|
||||
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharCtrlZ:
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharEnter, CharCtrlJ:
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
output = strings.Join(i.pastedLines, "\n") + "\n" + output
|
||||
i.pastedLines = nil
|
||||
}
|
||||
if output != "" {
|
||||
i.History.Add(output)
|
||||
}
|
||||
buf.MoveToEnd()
|
||||
fmt.Println()
|
||||
i.Prompt.UseAlt = false
|
||||
|
||||
return output, nil
|
||||
default:
|
||||
|
||||
@@ -73,7 +73,7 @@ _build_darwin() {
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
@@ -82,19 +82,19 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
|
||||
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/imagegen
|
||||
chmod +x dist/darwin/ollama-mlx
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
# create a temporary zip for notarization
|
||||
TEMP=$(mktemp -u).zip
|
||||
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
|
||||
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
||||
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
||||
rm -f "$TEMP"
|
||||
fi
|
||||
|
||||
@@ -154,38 +154,40 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
# Copy MLX metallib (architecture-independent, just use arm64 version)
|
||||
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
|
||||
else
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
|
||||
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
fi
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
$(xcrun -f stapler) staple dist/Ollama.app
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
|
||||
rm -f dist/Ollama.dmg
|
||||
|
||||
@@ -206,7 +208,7 @@ _build_macapp() {
|
||||
rm -f dist/rw*.dmg
|
||||
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
|
||||
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
$(xcrun -f stapler) staple dist/Ollama.dmg
|
||||
else
|
||||
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"
|
||||
|
||||
@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
||||
.
|
||||
fi
|
||||
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
deduplicate_cuda_libs() {
|
||||
local base_dir="$1"
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Run deduplication for each platform output directory
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist/linux_amd64"
|
||||
deduplicate_cuda_libs "./dist/linux_arm64"
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
|
||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist"
|
||||
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
|
||||
fi
|
||||
|
||||
# buildx behavior changes for single vs. multiplatform
|
||||
|
||||
60
scripts/deduplicate_cuda_libs.sh
Executable file
@@ -0,0 +1,60 @@
|
||||
#!/bin/sh
|
||||
#
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
# This script finds identical .so* files in mlx_cuda_* directories that exist
|
||||
# in corresponding cuda_* directories and replaces them with symlinks.
|
||||
#
|
||||
|
||||
set -eu
|
||||
|
||||
if [ $# -eq 0 ]; then
|
||||
echo "ERROR: No directory specified" >&2
|
||||
echo "Usage: $0 <base_directory>" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
base_dir="$1"
|
||||
|
||||
if [ ! -d "${base_dir}" ]; then
|
||||
echo "ERROR: Directory ${base_dir} does not exist" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
echo "Deduplication complete"
|
||||
@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
|
||||
redirectURL, err := challenge.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
|
||||
if redirectURL.Host != originalHost {
|
||||
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
|
||||
}
|
||||
|
||||
sha256sum := sha256.Sum256(nil)
|
||||
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||
|
||||
|
||||
113
server/auth_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
realm string
|
||||
originalHost string
|
||||
wantMismatch bool
|
||||
}{
|
||||
{"https://example.com/token", "example.com", false},
|
||||
{"https://example.com/token", "other.com", true},
|
||||
{"https://example.com/token", "localhost:8000", true},
|
||||
{"https://localhost:5000/token", "localhost:5000", false},
|
||||
{"https://localhost:5000/token", "localhost:6000", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.originalHost, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
|
||||
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
|
||||
|
||||
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
|
||||
if tt.wantMismatch && !isMismatch {
|
||||
t.Errorf("expected domain mismatch error, got: %v", err)
|
||||
}
|
||||
if !tt.wantMismatch && isMismatch {
|
||||
t.Errorf("unexpected domain mismatch error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistryChallenge(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantRealm, wantService, wantScope string
|
||||
}{
|
||||
{
|
||||
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
|
||||
"https://auth.example.com/token", "registry", "repo:foo:pull",
|
||||
},
|
||||
{
|
||||
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
|
||||
"https://r.ollama.ai/v2/token", "ollama", "-",
|
||||
},
|
||||
{"", "", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := parseRegistryChallenge(tt.input)
|
||||
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
|
||||
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
|
||||
tt.input, result.Realm, result.Service, result.Scope,
|
||||
tt.wantRealm, tt.wantService, tt.wantScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURL(t *testing.T) {
|
||||
challenge := registryChallenge{
|
||||
Realm: "https://auth.example.com/token",
|
||||
Service: "registry",
|
||||
Scope: "repo:foo:pull repo:bar:push",
|
||||
}
|
||||
|
||||
u, err := challenge.URL()
|
||||
if err != nil {
|
||||
t.Fatalf("URL() error: %v", err)
|
||||
}
|
||||
|
||||
if u.Host != "auth.example.com" {
|
||||
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
|
||||
}
|
||||
if u.Path != "/token" {
|
||||
t.Errorf("path = %q, want %q", u.Path, "/token")
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
if q.Get("service") != "registry" {
|
||||
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
|
||||
}
|
||||
if scopes := q["scope"]; len(scopes) != 2 {
|
||||
t.Errorf("scope count = %d, want 2", len(scopes))
|
||||
}
|
||||
if q.Get("ts") == "" {
|
||||
t.Error("missing ts")
|
||||
}
|
||||
if q.Get("nonce") == "" {
|
||||
t.Error("missing nonce")
|
||||
}
|
||||
|
||||
// Nonces should differ between calls
|
||||
u2, _ := challenge.URL()
|
||||
if q.Get("nonce") == u2.Query().Get("nonce") {
|
||||
t.Error("nonce should be unique per call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURLInvalid(t *testing.T) {
|
||||
challenge := registryChallenge{Realm: "://invalid"}
|
||||
if _, err := challenge.URL(); err == nil {
|
||||
t.Error("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}, base.Host)
|
||||
}
|
||||
|
||||
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
||||
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}, base.Host)
|
||||
}
|
||||
|
||||
return transfer.Upload(ctx, transfer.UploadOptions{
|
||||
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
|
||||
// Handle authentication error with one retry
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
|
||||
}
|
||||
|
||||
func (m *Manifest) RemoveLayers() error {
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if layer.Digest != "" {
|
||||
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
ms, err := Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build set of digests still in use by other manifests
|
||||
inUse := make(map[string]struct{})
|
||||
for _, other := range ms {
|
||||
for _, layer := range append(other.Layers, other.Config) {
|
||||
if layer.Digest != "" {
|
||||
inUse[layer.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove layers not used by any other manifest
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if layer.Digest == "" {
|
||||
continue
|
||||
}
|
||||
if _, used := inUse[layer.Digest]; used {
|
||||
continue
|
||||
}
|
||||
blob, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
198
server/routes.go
@@ -51,7 +51,7 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
||||
xserver "github.com/ollama/ollama/x/server"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
@@ -164,29 +164,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
// ScheduleImageGenRunner schedules an image generation model runner.
|
||||
// This implements the imagegenapi.RunnerScheduler interface.
|
||||
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
|
||||
m := &Model{
|
||||
Name: modelName,
|
||||
ShortName: modelName,
|
||||
ModelPath: modelName, // For image gen, ModelPath is just the model name
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.llama, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
@@ -214,12 +191,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a known image generation model
|
||||
if imagegen.ResolveModelName(req.Model) != "" {
|
||||
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
@@ -1124,6 +1095,31 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
QuantizationLevel: m.Config.FileType,
|
||||
}
|
||||
|
||||
// For image generation models, populate details from imagegen package
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||
modelDetails.Family = info.Architecture
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||
modelDetails.QuantizationLevel = info.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate details from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||
modelDetails.Family = arch
|
||||
}
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
|
||||
}
|
||||
}
|
||||
// Get torch_dtype directly from config.json for quantization level
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
||||
modelDetails.QuantizationLevel = dtype
|
||||
}
|
||||
}
|
||||
|
||||
if req.System != "" {
|
||||
m.System = req.System
|
||||
}
|
||||
@@ -1206,6 +1202,30 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
resp.ModelInfo = info
|
||||
}
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1567,6 +1587,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Experimental: image generation endpoint (may change or be removed)
|
||||
r.POST("/x/generate", s.XGenerateHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
@@ -1574,13 +1597,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoint (uses experimental /x/generate)
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.XGenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Experimental image generation support
|
||||
imagegenapi.RegisterRoutes(r, s)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
@@ -1889,6 +1911,106 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||
}
|
||||
|
||||
// XGenerateHandler handles the experimental /x/generate endpoint for image generation.
|
||||
// This endpoint is experimental and may change or be removed in future versions.
|
||||
func (s *Server) XGenerateHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
var req api.XGenerateRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(name.String())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Check that this is an image generation model
|
||||
if !slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support image generation", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule the runner
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{model.CapabilityImageGeneration}, nil, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
// Handle load-only request
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusOK, api.XGenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.XGenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: cr.Done,
|
||||
}
|
||||
|
||||
// Image generation progress
|
||||
if cr.TotalSteps > 0 {
|
||||
res.Status = "generating image"
|
||||
res.Completed = int64(cr.Step)
|
||||
res.Total = int64(cr.TotalSteps)
|
||||
}
|
||||
|
||||
// Final image
|
||||
if cr.Image != "" {
|
||||
res.Images = []string{cr.Image}
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func toolCallId() string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, 8)
|
||||
@@ -2059,8 +2181,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||
if _, ok := c.Get("relax_thinking"); ok {
|
||||
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
|
||||
req.Think = nil
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +16,6 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
|
||||
func (s *mockLlm) HasExited() bool { return false }
|
||||
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
|
||||
|
||||
// TestImageGenCapabilityDetection verifies that models with "image" capability
|
||||
// are correctly identified and routed differently from language models.
|
||||
func TestImageGenCapabilityDetection(t *testing.T) {
|
||||
// Model with image capability should be detected
|
||||
imageModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
|
||||
|
||||
// Model without image capability should not be detected
|
||||
langModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"completion"},
|
||||
},
|
||||
}
|
||||
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
|
||||
|
||||
// Empty capabilities should not match
|
||||
emptyModel := &Model{}
|
||||
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
|
||||
}
|
||||
|
||||
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
|
||||
// loaded in the scheduler can be evicted by a language model request.
|
||||
// loaded in the scheduler can be evicted when idle.
|
||||
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, "/fake/image/model", runner.modelPath)
|
||||
}
|
||||
|
||||
// TestImageGenSchedulerCoexistence verifies that image generation models
|
||||
// can coexist with language models in the scheduler and VRAM is tracked correctly.
|
||||
func TestImageGenSchedulerCoexistence(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
|
||||
// Load both an imagegen runner and a language model runner
|
||||
imageGenRunner := &runnerRef{
|
||||
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
|
||||
modelPath: "/fake/flux/model",
|
||||
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
langModelRunner := &runnerRef{
|
||||
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
|
||||
modelPath: "/fake/llama3/model",
|
||||
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["/fake/flux/model"] = imageGenRunner
|
||||
s.loaded["/fake/llama3/model"] = langModelRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify both are loaded
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
require.NotNil(t, s.loaded["/fake/flux/model"])
|
||||
require.NotNil(t, s.loaded["/fake/llama3/model"])
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify updateFreeSpace accounts for both
|
||||
gpus := []ml.DeviceInfo{
|
||||
{
|
||||
DeviceID: ml.DeviceID{Library: "Metal"},
|
||||
TotalMemory: 24 * format.GigaByte,
|
||||
FreeMemory: 24 * format.GigaByte,
|
||||
},
|
||||
}
|
||||
s.updateFreeSpace(gpus)
|
||||
|
||||
// Free memory should be reduced by both models
|
||||
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
|
||||
require.Equal(t, expectedFree, gpus[0].FreeMemory)
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
w.Rollback()
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
24
x/README.md
@@ -1,24 +0,0 @@
|
||||
# Experimental Features
|
||||
|
||||
## MLX Backend
|
||||
|
||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||
|
||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
|
||||
|
||||
```
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install --component MLX
|
||||
go build -tags mlx .
|
||||
```
|
||||
|
||||
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
|
||||
|
||||
## Image Generation
|
||||
|
||||
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
|
||||
|
||||
```
|
||||
go build -o imagegen ./x/imagegen/cmd/engine
|
||||
```
|
||||
@@ -41,6 +41,7 @@ var optionLabels = []string{
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"web_search": "Web Search",
|
||||
"web_fetch": "Web Fetch",
|
||||
}
|
||||
|
||||
// ToolDisplayName returns the human-readable display name for a tool.
|
||||
@@ -565,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
// For web fetch, show URL and internet notice
|
||||
if toolName == "web_fetch" {
|
||||
if url, ok := args["url"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||
sb.WriteString("Uses internet via ollama.com")
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Generic display
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
||||
if len(args) > 0 {
|
||||
@@ -1017,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
}
|
||||
}
|
||||
|
||||
if toolName == "web_fetch" {
|
||||
if url, ok := args["url"].(string); ok {
|
||||
// Truncate long URLs
|
||||
if len(url) > 50 {
|
||||
url = url[:47] + "..."
|
||||
}
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
||||
}
|
||||
|
||||
|
||||
308
x/cmd/run.go
@@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -130,6 +131,7 @@ type RunOptions struct {
|
||||
KeepAlive *api.Duration
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
Verbose bool
|
||||
|
||||
// Agent fields (managed externally for session persistence)
|
||||
Tools *tools.Registry
|
||||
@@ -178,6 +180,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
var consecutiveErrors int // Track consecutive 500 errors for retry limit
|
||||
var latest api.ChatResponse
|
||||
|
||||
role := "assistant"
|
||||
messages := opts.Messages
|
||||
@@ -187,6 +190,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
latest = response
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
@@ -483,6 +487,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
if opts.Verbose {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
@@ -634,12 +642,13 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
// This is called from cmd.go when --experimental flag is set.
|
||||
// If yoloMode is true, all tool approvals are skipped.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
|
||||
// If enableWebsearch is true, the web search tool is registered.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -660,6 +669,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
|
||||
// Register web search and web fetch tools if enabled via flag
|
||||
if enableWebsearch {
|
||||
toolRegistry.RegisterWebSearch()
|
||||
toolRegistry.RegisterWebFetch()
|
||||
}
|
||||
|
||||
if toolRegistry.Has("bash") {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
|
||||
@@ -667,6 +682,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
|
||||
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
||||
}
|
||||
@@ -677,6 +697,8 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
var format string
|
||||
var system string
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -688,6 +710,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
continue
|
||||
case err != nil:
|
||||
@@ -707,6 +730,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
continue
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load Load a different model")
|
||||
fmt.Fprintln(os.Stderr, " /save Save session as a model")
|
||||
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
@@ -716,6 +743,280 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
switch args[1] {
|
||||
case "history":
|
||||
scanner.HistoryEnable()
|
||||
case "nohistory":
|
||||
scanner.HistoryDisable()
|
||||
case "wordwrap":
|
||||
wordWrap = true
|
||||
fmt.Println("Set 'wordwrap' mode.")
|
||||
case "nowordwrap":
|
||||
wordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
if err := cmd.Flags().Set("verbose", "true"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
if err := cmd.Flags().Set("verbose", "false"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
thinkValue := api.ThinkValue{Value: true}
|
||||
var maybeLevel string
|
||||
if len(args) > 2 {
|
||||
maybeLevel = args[2]
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
thinkValue.Value = maybeLevel
|
||||
}
|
||||
think = &thinkValue
|
||||
// Check if model supports thinking
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
|
||||
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
|
||||
} else {
|
||||
fmt.Println("Set 'think' mode.")
|
||||
}
|
||||
case "nothink":
|
||||
think = &api.ThinkValue{Value: false}
|
||||
// Check if model supports thinking
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
|
||||
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println("Set 'nothink' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
} else {
|
||||
format = args[2]
|
||||
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||
}
|
||||
case "noformat":
|
||||
format = ""
|
||||
fmt.Println("Disabled format.")
|
||||
case "parameter":
|
||||
if len(args) < 4 {
|
||||
fmt.Println("Usage: /set parameter <name> <value>")
|
||||
continue
|
||||
}
|
||||
params := args[3:]
|
||||
fp, err := api.FormatParams(map[string][]string{args[2]: params})
|
||||
if err != nil {
|
||||
fmt.Printf("Couldn't set parameter: %q\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
|
||||
options[args[2]] = fp[args[2]]
|
||||
case "system":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /set system <message>")
|
||||
continue
|
||||
}
|
||||
|
||||
system = strings.Join(args[2:], " ")
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: modelName,
|
||||
Options: options,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model")
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, " Model\n")
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
|
||||
if resp.Details.Family != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
|
||||
}
|
||||
if resp.Details.ParameterSize != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
|
||||
}
|
||||
if resp.Details.QuantizationLevel != "" {
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
|
||||
}
|
||||
if len(resp.Capabilities) > 0 {
|
||||
caps := make([]string, len(resp.Capabilities))
|
||||
for i, c := range resp.Capabilities {
|
||||
caps[i] = string(c)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
|
||||
}
|
||||
fmt.Fprintln(os.Stderr)
|
||||
case "license":
|
||||
if resp.License == "" {
|
||||
fmt.Println("No license was specified for this model.")
|
||||
} else {
|
||||
fmt.Println(resp.License)
|
||||
}
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
fmt.Println("Model defined parameters:")
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println(" No additional parameters were specified.")
|
||||
} else {
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
}
|
||||
}
|
||||
if len(options) > 0 {
|
||||
fmt.Println("\nUser defined parameters:")
|
||||
for k, v := range options {
|
||||
fmt.Printf(" %-30s %v\n", k, v)
|
||||
}
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
case system != "":
|
||||
fmt.Println(system + "\n")
|
||||
case resp.System != "":
|
||||
fmt.Println(resp.System + "\n")
|
||||
default:
|
||||
fmt.Println("No system message was specified for this model.")
|
||||
}
|
||||
case "template":
|
||||
if resp.Template != "" {
|
||||
fmt.Println(resp.Template)
|
||||
} else {
|
||||
fmt.Println("No prompt template was specified for this model.")
|
||||
}
|
||||
default:
|
||||
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/load"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage: /load <modelname>")
|
||||
continue
|
||||
}
|
||||
newModelName := args[1]
|
||||
fmt.Printf("Loading model '%s'\n", newModelName)
|
||||
|
||||
// Create progress spinner
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
// Get client
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if model exists and get its info
|
||||
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||
} else {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For cloud models, no need to preload
|
||||
if info.RemoteHost == "" {
|
||||
// Preload the model by sending an empty generate request
|
||||
req := &api.GenerateRequest{
|
||||
Model: newModelName,
|
||||
Think: think,
|
||||
}
|
||||
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
p.StopAndClear()
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||
} else if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("error loading model: %v\n", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
p.StopAndClear()
|
||||
modelName = newModelName
|
||||
messages = []api.Message{}
|
||||
approval.Reset()
|
||||
continue
|
||||
case strings.HasPrefix(line, "/save"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage: /save <modelname>")
|
||||
continue
|
||||
}
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
continue
|
||||
}
|
||||
req := &api.CreateRequest{
|
||||
Model: args[1],
|
||||
From: modelName,
|
||||
Parameters: options,
|
||||
Messages: messages,
|
||||
}
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
continue
|
||||
@@ -727,10 +1028,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
verbose, _ := cmd.Flags().GetBool("verbose")
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
@@ -738,6 +1041,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
Verbose: verbose,
|
||||
}
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
|
||||
282
x/create/client/create.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Package client provides client-side model creation for safetensors-based models.
|
||||
//
|
||||
// This package is in x/ because the safetensors model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/create, so x/create
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// ModelfileConfig holds configuration extracted from a Modelfile.
|
||||
type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "fp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
|
||||
func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// Detect model type
|
||||
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
|
||||
isImageGen := create.IsTensorModelDir(opts.ModelDir)
|
||||
|
||||
if !isSafetensors && !isImageGen {
|
||||
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
|
||||
}
|
||||
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
capabilities = []string{"image"}
|
||||
}
|
||||
|
||||
// Set up progress spinner
|
||||
statusMsg := "importing " + modelType
|
||||
spinner := progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
statusMsg = msg
|
||||
spinner = progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
var err error
|
||||
if isSafetensors {
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
|
||||
return create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
|
||||
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
|
||||
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
|
||||
return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if quantize != "" {
|
||||
return createQuantizedLayers(r, name, dtype, shape, quantize)
|
||||
}
|
||||
return createUnquantizedLayer(r, name)
|
||||
}
|
||||
}
|
||||
|
||||
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
|
||||
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []create.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale",
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias",
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []create.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: capabilities,
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, 0, len(layers))
|
||||
for _, l := range layers {
|
||||
serverLayers = append(serverLayers, server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// Add Modelfile layers if present
|
||||
if opts.Modelfile != nil {
|
||||
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
146
x/create/client/create_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelfileConfig(t *testing.T) {
|
||||
// Test that ModelfileConfig struct works as expected
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
System: "You are a helpful assistant.",
|
||||
License: "MIT",
|
||||
}
|
||||
|
||||
if config.Template != "{{ .Prompt }}" {
|
||||
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
|
||||
}
|
||||
if config.System != "You are a helpful assistant." {
|
||||
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
|
||||
}
|
||||
if config.License != "MIT" {
|
||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
config := &ModelfileConfig{}
|
||||
|
||||
if config.Template != "" {
|
||||
t.Errorf("Template should be empty, got %q", config.Template)
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Errorf("System should be empty, got %q", config.System)
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Errorf("License should be empty, got %q", config.License)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
// Test config with only some fields set
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
// System and License intentionally empty
|
||||
}
|
||||
|
||||
if config.Template == "" {
|
||||
t.Error("Template should not be empty")
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Error("System should be empty")
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Error("License should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
// Verify the minimum version constant is set
|
||||
if MinOllamaVersion == "" {
|
||||
t.Error("MinOllamaVersion should not be empty")
|
||||
}
|
||||
if MinOllamaVersion != "0.14.0" {
|
||||
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_InvalidDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for invalid directory
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: "/nonexistent/path",
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for directory without safetensors
|
||||
dir := t.TempDir()
|
||||
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: dir,
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
Modelfile: &ModelfileConfig{
|
||||
Template: "test",
|
||||
System: "system",
|
||||
License: "MIT",
|
||||
},
|
||||
}
|
||||
|
||||
if opts.ModelName != "my-model" {
|
||||
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
|
||||
}
|
||||
if opts.ModelDir != "/path/to/model" {
|
||||
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
|
||||
}
|
||||
if opts.Quantize != "fp8" {
|
||||
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
|
||||
}
|
||||
if opts.Modelfile == nil {
|
||||
t.Error("Modelfile should not be nil")
|
||||
}
|
||||
if opts.Modelfile.Template != "test" {
|
||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions_Defaults(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "test",
|
||||
ModelDir: "/tmp",
|
||||
}
|
||||
|
||||
// Quantize should default to empty
|
||||
if opts.Quantize != "" {
|
||||
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
|
||||
}
|
||||
|
||||
// Modelfile should default to nil
|
||||
if opts.Modelfile != nil {
|
||||
t.Error("Modelfile should be nil by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizeSupported(t *testing.T) {
|
||||
// This just verifies the function exists and returns a boolean
|
||||
// The actual value depends on build tags (mlx vs non-mlx)
|
||||
supported := QuantizeSupported()
|
||||
|
||||
// In non-mlx builds, this should be false
|
||||
// We can't easily test both cases, so just verify it returns something
|
||||
_ = supported
|
||||
}
|
||||
127
x/create/client/quantize.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it,
|
||||
// and returns safetensors data for the quantized weights, scales, and biases.
|
||||
// Supported quantization types: "fp8" (affine 8-bit)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
tmpDir := ensureTempDir()
|
||||
|
||||
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
|
||||
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
if _, err := io.Copy(tmpFile, r); err != nil {
|
||||
tmpFile.Close()
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// Load the tensor using MLX's native loader
|
||||
st, err := mlx.LoadSafetensorsNative(tmpPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
|
||||
}
|
||||
defer st.Free()
|
||||
|
||||
// Get the tensor (it's stored as "data" in our minimal safetensors format)
|
||||
arr := st.Get("data")
|
||||
if arr == nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
|
||||
}
|
||||
|
||||
// Convert to BFloat16 if needed (quantize expects float type)
|
||||
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
|
||||
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
|
||||
mlx.Eval(arr)
|
||||
}
|
||||
|
||||
// Quantize based on quantization type
|
||||
var qweight, scales, qbiases *mlx.Array
|
||||
switch quantize {
|
||||
case "fp8":
|
||||
// affine mode: group_size=32, bits=8
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
|
||||
// Eval and make contiguous for data access
|
||||
qweight = mlx.Contiguous(qweight)
|
||||
scales = mlx.Contiguous(scales)
|
||||
if qbiases != nil {
|
||||
qbiases = mlx.Contiguous(qbiases)
|
||||
mlx.Eval(qweight, scales, qbiases)
|
||||
} else {
|
||||
mlx.Eval(qweight, scales)
|
||||
}
|
||||
|
||||
// Get shapes
|
||||
qweightShape = qweight.Shape()
|
||||
scalesShape = scales.Shape()
|
||||
|
||||
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
|
||||
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
|
||||
defer os.Remove(qweightPath)
|
||||
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
|
||||
}
|
||||
qweightData, err = os.ReadFile(qweightPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
|
||||
}
|
||||
|
||||
// Save scales using MLX's native safetensors
|
||||
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
|
||||
defer os.Remove(scalesPath)
|
||||
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
|
||||
}
|
||||
scalesData, err = os.ReadFile(scalesPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
|
||||
}
|
||||
|
||||
// Affine mode returns qbiases for zero-point offset
|
||||
if qbiases != nil {
|
||||
qbiasShape = qbiases.Shape()
|
||||
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
|
||||
defer os.Remove(qbiasPath)
|
||||
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
|
||||
}
|
||||
qbiasData, err = os.ReadFile(qbiasPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
|
||||
}
|
||||
|
||||
// QuantizeSupported returns true if quantization is supported (MLX build)
|
||||
func QuantizeSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||
func ensureTempDir() string {
|
||||
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
|
||||
os.MkdirAll(tmpDir, 0755)
|
||||
return tmpDir
|
||||
}
|
||||
18
x/create/client/quantize_stub.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build !mlx
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// quantizeTensor is not available without MLX
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
// QuantizeSupported returns false when MLX is not available
|
||||
func QuantizeSupported() bool {
|
||||
return false
|
||||
}
|
||||
399
x/create/create.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// ModelConfig represents the config blob stored with a model.
|
||||
type ModelConfig struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
}
|
||||
|
||||
// Manifest represents the manifest JSON structure.
|
||||
type Manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config ManifestLayer `json:"config"`
|
||||
Layers []ManifestLayer `json:"layers"`
|
||||
}
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
type ManifestLayer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// defaultManifestDir returns the manifest storage directory.
|
||||
func defaultManifestDir() string {
|
||||
return filepath.Join(envconfig.Models(), "manifests")
|
||||
}
|
||||
|
||||
// defaultBlobDir returns the blob storage directory.
|
||||
func defaultBlobDir() string {
|
||||
return filepath.Join(envconfig.Models(), "blobs")
|
||||
}
|
||||
|
||||
// resolveManifestPath converts a model name to a manifest file path.
|
||||
func resolveManifestPath(modelName string) string {
|
||||
host := "registry.ollama.ai"
|
||||
namespace := "library"
|
||||
name := modelName
|
||||
tag := "latest"
|
||||
|
||||
if idx := strings.LastIndex(name, ":"); idx != -1 {
|
||||
tag = name[idx+1:]
|
||||
name = name[:idx]
|
||||
}
|
||||
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
host = parts[0]
|
||||
namespace = parts[1]
|
||||
name = parts[2]
|
||||
case 2:
|
||||
namespace = parts[0]
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
|
||||
}
|
||||
|
||||
// loadManifest loads a manifest for the given model name.
|
||||
func loadManifest(modelName string) (*Manifest, error) {
|
||||
manifestPath := resolveManifestPath(modelName)
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// loadModelConfig loads the config blob for a model.
|
||||
func loadModelConfig(modelName string) (*ModelConfig, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the config blob
|
||||
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config ModelConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModel checks if a model was created with the experimental
|
||||
// safetensors builder by checking the model format in the config.
|
||||
func IsSafetensorsModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
|
||||
// (has completion capability, not image generation).
|
||||
func IsSafetensorsLLMModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
|
||||
}
|
||||
|
||||
// IsImageGenModel checks if a model is an image generation model
|
||||
// (has image capability).
|
||||
func IsImageGenModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
|
||||
}
|
||||
|
||||
// GetModelArchitecture returns the architecture from the model's config.json layer.
|
||||
func GetModelArchitecture(modelName string) (string, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Find the config.json layer
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
|
||||
blobName := strings.Replace(layer.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prefer model_type, fall back to first architecture
|
||||
if cfg.ModelType != "" {
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
if len(cfg.Architectures) > 0 {
|
||||
return cfg.Architectures[0], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("architecture not found in model config")
|
||||
}
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
|
||||
// by looking for config.json and at least one .safetensors file.
|
||||
func IsSafetensorsModelDir(dir string) bool {
|
||||
// Must have config.json
|
||||
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must have at least one .safetensors file
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// ShouldQuantize returns true if a tensor should be quantized.
|
||||
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
|
||||
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
|
||||
func ShouldQuantize(name, component string) bool {
|
||||
// Image gen specific: skip VAE entirely
|
||||
if component == "vae" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip embeddings
|
||||
if strings.Contains(name, "embed") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip layer norms and RMS norms
|
||||
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip biases
|
||||
if strings.HasSuffix(name, ".bias") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize weights
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
|
||||
// This is a more detailed check that also considers tensor dimensions.
|
||||
func ShouldQuantizeTensor(name string, shape []int32) bool {
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
return false
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size (32)
|
||||
if shape[len(shape)-1]%32 != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateSafetensorsModel imports a standard safetensors model from a directory.
|
||||
// This handles Hugging Face style models with config.json and *.safetensors files.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
|
||||
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
// Process all safetensors files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(modelDir, entry.Name())
|
||||
|
||||
// Extract individual tensors from safetensors file
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
quantizeMsg := ""
|
||||
if quantize != "" {
|
||||
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
|
||||
}
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
|
||||
// Process all JSON config files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip the index file as we don't need it after extraction
|
||||
if entry.Name() == "model.safetensors.index.json" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfgPath := entry.Name()
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
|
||||
fn(fmt.Sprintf("importing config %s", cfgPath))
|
||||
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Use config.json as the config layer
|
||||
if cfgPath == "config.json" {
|
||||
configLayer = layer
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if configLayer.Digest == "" {
|
||||
return fmt.Errorf("config.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
752
x/create/create_test.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsTensorModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid diffusers model with model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "directory with other files but no model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsTensorModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid safetensors model with config.json and .safetensors file",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "config.json only, no safetensors files",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "safetensors file only, no config.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "multiple safetensors files with config.json",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsSafetensorsModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
|
||||
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
|
||||
if got != false {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
|
||||
func createMinimalSafetensors(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
|
||||
// Create a minimal safetensors file with a single float32 tensor
|
||||
header := map[string]interface{}{
|
||||
"test_tensor": map[string]interface{}{
|
||||
"dtype": "F32",
|
||||
"shape": []int{2, 2},
|
||||
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
|
||||
},
|
||||
}
|
||||
headerJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
// Pad header to 8-byte alignment
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
||||
|
||||
// Write file
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Write header size (8 bytes, little endian)
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
||||
t.Fatalf("failed to write header size: %v", err)
|
||||
}
|
||||
|
||||
// Write header
|
||||
if _, err := f.Write(headerJSON); err != nil {
|
||||
t.Fatalf("failed to write header: %v", err)
|
||||
}
|
||||
|
||||
// Write tensor data (16 bytes of zeros for 4 float32 values)
|
||||
if _, err := f.Write(make([]byte, 16)); err != nil {
|
||||
t.Fatalf("failed to write tensor data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Track what was created
|
||||
var createdLayers []LayerInfo
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var manifestConfigLayer LayerInfo
|
||||
var manifestLayers []LayerInfo
|
||||
var statusMessages []string
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:test",
|
||||
Size: int64(len(data)),
|
||||
MediaType: mediaType,
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:tensor_" + name,
|
||||
Size: int64(len(data)),
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return []LayerInfo{layer}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
manifestConfigLayer = config
|
||||
manifestLayers = layers
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
// Run CreateSafetensorsModel
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify manifest was written
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-model" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
|
||||
}
|
||||
|
||||
// Verify config layer was set
|
||||
if manifestConfigLayer.Name != "config.json" {
|
||||
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
|
||||
}
|
||||
|
||||
// Verify we have at least one tensor and one config layer
|
||||
hasTensor := false
|
||||
hasConfig := false
|
||||
for _, layer := range manifestLayers {
|
||||
if layer.Name == "test_tensor" {
|
||||
hasTensor = true
|
||||
}
|
||||
if layer.Name == "config.json" {
|
||||
hasConfig = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasTensor {
|
||||
t.Error("no tensor layer found in manifest")
|
||||
}
|
||||
if !hasConfig {
|
||||
t.Error("no config layer found in manifest")
|
||||
}
|
||||
|
||||
// Verify status messages were sent
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only a safetensors file, no config.json
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Mock callbacks (minimal)
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing config.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
return LayerInfo{}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
return []LayerInfo{{}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create model.safetensors.index.json (should be skipped)
|
||||
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var configNames []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
configNames = append(configNames, name)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify model.safetensors.index.json was not included
|
||||
for _, name := range configNames {
|
||||
if name == "model.safetensors.index.json" {
|
||||
t.Error("model.safetensors.index.json should have been skipped")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManifestPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
wantParts []string // Parts that should appear in the path
|
||||
}{
|
||||
{
|
||||
name: "simple model name",
|
||||
modelName: "llama2",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with tag",
|
||||
modelName: "llama2:7b",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace",
|
||||
modelName: "myuser/mymodel",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace and tag",
|
||||
modelName: "myuser/mymodel:v1",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
|
||||
},
|
||||
{
|
||||
name: "fully qualified model name",
|
||||
modelName: "registry.example.com/namespace/model:tag",
|
||||
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := resolveManifestPath(tt.modelName)
|
||||
|
||||
for _, part := range tt.wantParts {
|
||||
if !strings.Contains(got, part) {
|
||||
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerInfo(t *testing.T) {
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:abc123",
|
||||
Size: 1024,
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: "model.weight",
|
||||
}
|
||||
|
||||
if layer.Digest != "sha256:abc123" {
|
||||
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
|
||||
}
|
||||
if layer.Size != 1024 {
|
||||
t.Errorf("Size = %d, want %d", layer.Size, 1024)
|
||||
}
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
|
||||
}
|
||||
if layer.Name != "model.weight" {
|
||||
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig(t *testing.T) {
|
||||
config := ModelConfig{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"completion", "chat"},
|
||||
}
|
||||
|
||||
if config.ModelFormat != "safetensors" {
|
||||
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
|
||||
}
|
||||
if len(config.Capabilities) != 2 {
|
||||
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifest(t *testing.T) {
|
||||
manifest := Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
Config: ManifestLayer{
|
||||
MediaType: "application/vnd.docker.container.image.v1+json",
|
||||
Digest: "sha256:config",
|
||||
Size: 100,
|
||||
},
|
||||
Layers: []ManifestLayer{
|
||||
{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Digest: "sha256:layer1",
|
||||
Size: 1000,
|
||||
Name: "weight.bin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if manifest.SchemaVersion != 2 {
|
||||
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
|
||||
}
|
||||
if manifest.Config.Digest != "sha256:config" {
|
||||
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
|
||||
}
|
||||
if len(manifest.Layers) != 1 {
|
||||
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
|
||||
}
|
||||
if manifest.Layers[0].Name != "weight.bin" {
|
||||
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
component string
|
||||
want bool
|
||||
}{
|
||||
// VAE component should never be quantized
|
||||
{"vae weight", "decoder.weight", "vae", false},
|
||||
{"vae bias", "decoder.bias", "vae", false},
|
||||
|
||||
// Embeddings should not be quantized
|
||||
{"embedding weight", "embed_tokens.weight", "", false},
|
||||
{"embedding in name", "token_embedding.weight", "", false},
|
||||
|
||||
// Norms should not be quantized
|
||||
{"layer norm", "layer_norm.weight", "", false},
|
||||
{"rms norm", "rms_norm.weight", "", false},
|
||||
{"ln prefix", "ln_1.weight", "", false},
|
||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias tensor", "attention.bias", "", false},
|
||||
{"proj bias", "o_proj.bias", "", false},
|
||||
|
||||
// Linear weights should be quantized
|
||||
{"linear weight", "q_proj.weight", "", true},
|
||||
{"attention weight", "self_attn.weight", "", true},
|
||||
{"mlp weight", "mlp.gate_proj.weight", "", true},
|
||||
|
||||
// Transformer component weights should be quantized
|
||||
{"transformer weight", "layers.0.weight", "transformer", true},
|
||||
{"text_encoder weight", "encoder.weight", "text_encoder", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantize(tt.tensor, tt.component)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantizeTensor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
want bool
|
||||
}{
|
||||
// 2D tensors with sufficient size should be quantized
|
||||
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
|
||||
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
|
||||
|
||||
// Small tensors should not be quantized (< 1024 elements)
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, false},
|
||||
|
||||
// 1D tensors should not be quantized
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
|
||||
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
|
||||
|
||||
// Norms should not be quantized regardless of shape
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var quantizeRequested []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
// Run with quantize enabled
|
||||
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify quantize was passed to callback (will be false for small test tensor)
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalImageGenModel creates a minimal diffusers-style model directory
|
||||
func createMinimalImageGenModel(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
|
||||
// Create model_index.json
|
||||
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
|
||||
t.Fatalf("failed to write model_index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create transformer directory with a safetensors file
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
// Create transformer config
|
||||
transformerConfig := `{"hidden_size": 3072}`
|
||||
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
|
||||
t.Fatalf("failed to write transformer config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var statusMessages []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-imagegen" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
|
||||
}
|
||||
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only transformer without model_index.json
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing model_index.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var quantizeRequested []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -12,38 +12,24 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// CreateModel imports an image generation model from a directory.
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
|
||||
// Supported quantization types: fp8 (or empty for no quantization).
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "fp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
var totalParams int64 // Count parameters from original tensor shapes
|
||||
var torchDtype string // Read from component config for quantization display
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
@@ -74,7 +60,11 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
|
||||
quantizeMsg := ""
|
||||
if quantize != "" && component != "vae" {
|
||||
quantizeMsg = ", quantizing to " + quantize
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
@@ -83,22 +73,52 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Count parameters from original tensor shape
|
||||
if len(td.Shape) > 0 {
|
||||
numElements := int64(1)
|
||||
for _, dim := range td.Shape {
|
||||
numElements *= int64(dim)
|
||||
}
|
||||
totalParams += numElements
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// Use path-style name: "component/tensor_name"
|
||||
fullName := component + "/" + tensorName
|
||||
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Read torch_dtype from text_encoder config for quantization display
|
||||
if torchDtype == "" {
|
||||
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
|
||||
if data, err := os.ReadFile(textEncoderConfig); err == nil {
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
|
||||
torchDtype = cfg.TorchDtype
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Import config files
|
||||
configFiles := []string{
|
||||
"model_index.json",
|
||||
@@ -122,7 +142,7 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
|
||||
var r io.Reader
|
||||
|
||||
// For model_index.json, normalize to Ollama format
|
||||
// For model_index.json, normalize to Ollama format and add metadata
|
||||
if cfgPath == "model_index.json" {
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
@@ -141,6 +161,16 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
}
|
||||
delete(cfg, "_diffusers_version")
|
||||
|
||||
// Add parameter count (counted from tensor shapes during import)
|
||||
cfg["parameter_count"] = totalParams
|
||||
|
||||
// Add quantization info - use quantize type if set, otherwise torch_dtype
|
||||
if quantize != "" {
|
||||
cfg["quantization"] = strings.ToUpper(quantize)
|
||||
} else {
|
||||
cfg["quantization"] = torchDtype
|
||||
}
|
||||
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
||||
@@ -181,3 +211,12 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
|
||||
// MLX requires the last dimension to be divisible by the group size (32).
|
||||
func canQuantizeShape(shape []int32) bool {
|
||||
if len(shape) < 2 {
|
||||
return false
|
||||
}
|
||||
return shape[len(shape)-1]%32 == 0
|
||||
}
|
||||
@@ -234,3 +234,17 @@ ollama create z-image
|
||||
3. Copy config files (*.json) as config layers
|
||||
4. Write manifest
|
||||
```
|
||||
|
||||
## FP8 Quantization
|
||||
|
||||
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
cd ./weights/Z-Image-Turbo
|
||||
ollama create z-image-fp8 --quantize fp8
|
||||
```
|
||||
|
||||
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.
|
||||
|
||||
|
||||
@@ -1,235 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// RunnerScheduler is the interface for scheduling a model runner.
|
||||
// This is implemented by server.Server to avoid circular imports.
|
||||
type RunnerScheduler interface {
|
||||
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the image generation API routes.
|
||||
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
|
||||
r.POST("/v1/images/generations", func(c *gin.Context) {
|
||||
ImageGenerationHandler(c, scheduler)
|
||||
})
|
||||
}
|
||||
|
||||
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
|
||||
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
var req ImageGenerationRequest
|
||||
if err := c.BindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if req.N == 0 {
|
||||
req.N = 1
|
||||
}
|
||||
if req.Size == "" {
|
||||
req.Size = "1024x1024"
|
||||
}
|
||||
if req.ResponseFormat == "" {
|
||||
req.ResponseFormat = "b64_json"
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
if imagegen.ResolveModelName(req.Model) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size
|
||||
width, height := parseSize(req.Size)
|
||||
|
||||
// Build options - we repurpose NumCtx/NumGPU for width/height
|
||||
opts := api.Options{}
|
||||
opts.NumCtx = int(width)
|
||||
opts.NumGPU = int(height)
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
} else {
|
||||
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
} else {
|
||||
progress := parseProgress(resp.Content)
|
||||
if progress.Total > 0 {
|
||||
c.SSEvent("progress", progress)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.SSEvent("done", buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
var imagePath string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imagePath = extractPath(resp.Content)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildResponse(imagePath, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractPath(content string) string {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
return strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseProgress(content string) ImageProgressEvent {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
return ImageProgressEvent{Step: step, Total: total}
|
||||
}
|
||||
|
||||
func buildResponse(imagePath, format string) ImageGenerationResponse {
|
||||
resp := ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: make([]ImageData, 1),
|
||||
}
|
||||
|
||||
if imagePath == "" {
|
||||
return resp
|
||||
}
|
||||
|
||||
if format == "url" {
|
||||
resp.Data[0].URL = "file://" + imagePath
|
||||
} else {
|
||||
data, err := os.ReadFile(imagePath)
|
||||
if err == nil {
|
||||
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't block - channel is already being consumed
|
||||
_ = err
|
||||
}
|
||||
}()
|
||||
|
||||
streamFn(c, ch)
|
||||
}
|
||||
|
||||
// GenerateResponse matches api.GenerateResponse structure for streaming.
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Package api provides OpenAI-compatible image generation API types.
|
||||
package api
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageData contains the generated image data.
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// ImageProgressEvent is sent during streaming to indicate generation progress.
|
||||
type ImageProgressEvent struct {
|
||||
Step int `json:"step"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
197
x/imagegen/cache/teacache.go
vendored
Normal file
@@ -0,0 +1,197 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package cache provides caching mechanisms for diffusion model inference.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
|
||||
// It caches the transformer output and reuses it when timestep values
|
||||
// are similar between consecutive steps.
|
||||
//
|
||||
// For CFG (classifier-free guidance), it caches pos and neg predictions
|
||||
// separately and always computes CFG fresh to avoid error amplification.
|
||||
//
|
||||
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
|
||||
// https://github.com/ali-vilab/TeaCache
|
||||
type TeaCache struct {
|
||||
// Cached transformer output from last computed step (non-CFG mode)
|
||||
cachedOutput *mlx.Array
|
||||
|
||||
// Cached CFG outputs (pos and neg separately)
|
||||
cachedPosOutput *mlx.Array
|
||||
cachedNegOutput *mlx.Array
|
||||
|
||||
// Previous timestep value for difference calculation
|
||||
prevTimestep float32
|
||||
|
||||
// Accumulated difference for rescaling
|
||||
accumulatedDiff float32
|
||||
|
||||
// Configuration
|
||||
threshold float32 // Threshold for recomputation decision
|
||||
rescaleFactor float32 // Model-specific rescaling factor
|
||||
skipEarlySteps int // Number of early steps to never cache
|
||||
|
||||
// Statistics
|
||||
cacheHits int
|
||||
cacheMisses int
|
||||
}
|
||||
|
||||
// TeaCacheConfig holds configuration for TeaCache.
|
||||
type TeaCacheConfig struct {
|
||||
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
|
||||
// Recommended: 0.05-0.15 for image models
|
||||
Threshold float32
|
||||
|
||||
// Rescale factor to adjust timestep embedding differences.
|
||||
// Model-specific, typically 1.0-2.0
|
||||
RescaleFactor float32
|
||||
|
||||
// SkipEarlySteps: number of early steps to always compute (never cache).
|
||||
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
|
||||
SkipEarlySteps int
|
||||
}
|
||||
|
||||
// DefaultTeaCacheConfig returns default configuration for TeaCache.
|
||||
func DefaultTeaCacheConfig() *TeaCacheConfig {
|
||||
return &TeaCacheConfig{
|
||||
Threshold: 0.1,
|
||||
RescaleFactor: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTeaCache creates a new TeaCache instance.
|
||||
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
|
||||
if cfg == nil {
|
||||
cfg = DefaultTeaCacheConfig()
|
||||
}
|
||||
return &TeaCache{
|
||||
threshold: cfg.Threshold,
|
||||
rescaleFactor: cfg.RescaleFactor,
|
||||
skipEarlySteps: cfg.SkipEarlySteps,
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldCompute determines if we should compute the full forward pass
|
||||
// or reuse the cached output based on timestep similarity.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. First step always computes
|
||||
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
|
||||
// 3. If accumulated difference > threshold, compute new output
|
||||
// 4. Otherwise, reuse cached output
|
||||
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
|
||||
// Always compute early steps (critical for structure)
|
||||
// Check both regular cache and CFG cache
|
||||
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
|
||||
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
|
||||
return true
|
||||
}
|
||||
|
||||
// Compute absolute difference between current and previous timestep
|
||||
diff := timestep - tc.prevTimestep
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
|
||||
// Apply rescaling factor
|
||||
scaledDiff := diff * tc.rescaleFactor
|
||||
|
||||
// Accumulate difference (helps track drift over multiple cached steps)
|
||||
tc.accumulatedDiff += scaledDiff
|
||||
|
||||
// Decision based on accumulated difference
|
||||
if tc.accumulatedDiff > tc.threshold {
|
||||
tc.accumulatedDiff = 0 // Reset accumulator
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
|
||||
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
|
||||
// Free previous cached output
|
||||
if tc.cachedOutput != nil {
|
||||
tc.cachedOutput.Free()
|
||||
}
|
||||
|
||||
// Store new cached values
|
||||
tc.cachedOutput = output
|
||||
tc.prevTimestep = timestep
|
||||
tc.cacheMisses++
|
||||
}
|
||||
|
||||
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
|
||||
// This allows CFG to be computed fresh each step, avoiding error amplification.
|
||||
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
|
||||
// Free previous cached outputs
|
||||
if tc.cachedPosOutput != nil {
|
||||
tc.cachedPosOutput.Free()
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
tc.cachedNegOutput.Free()
|
||||
}
|
||||
|
||||
// Store new cached values
|
||||
tc.cachedPosOutput = posOutput
|
||||
tc.cachedNegOutput = negOutput
|
||||
tc.prevTimestep = timestep
|
||||
tc.cacheMisses++
|
||||
}
|
||||
|
||||
// GetCached returns the cached output (non-CFG mode).
|
||||
func (tc *TeaCache) GetCached() *mlx.Array {
|
||||
tc.cacheHits++
|
||||
return tc.cachedOutput
|
||||
}
|
||||
|
||||
// GetCFGCached returns cached pos and neg outputs for CFG mode.
|
||||
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
|
||||
tc.cacheHits++
|
||||
return tc.cachedPosOutput, tc.cachedNegOutput
|
||||
}
|
||||
|
||||
// HasCFGCache returns true if CFG cache is available.
|
||||
func (tc *TeaCache) HasCFGCache() bool {
|
||||
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
|
||||
}
|
||||
|
||||
// Arrays returns all arrays that should be kept alive.
|
||||
func (tc *TeaCache) Arrays() []*mlx.Array {
|
||||
var arrays []*mlx.Array
|
||||
if tc.cachedOutput != nil {
|
||||
arrays = append(arrays, tc.cachedOutput)
|
||||
}
|
||||
if tc.cachedPosOutput != nil {
|
||||
arrays = append(arrays, tc.cachedPosOutput)
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
arrays = append(arrays, tc.cachedNegOutput)
|
||||
}
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Stats returns cache hit/miss statistics.
|
||||
func (tc *TeaCache) Stats() (hits, misses int) {
|
||||
return tc.cacheHits, tc.cacheMisses
|
||||
}
|
||||
|
||||
// Free releases all cached arrays.
|
||||
func (tc *TeaCache) Free() {
|
||||
if tc.cachedOutput != nil {
|
||||
tc.cachedOutput.Free()
|
||||
tc.cachedOutput = nil
|
||||
}
|
||||
if tc.cachedPosOutput != nil {
|
||||
tc.cachedPosOutput.Free()
|
||||
tc.cachedPosOutput = nil
|
||||
}
|
||||
if tc.cachedNegOutput != nil {
|
||||
tc.cachedNegOutput.Free()
|
||||
tc.cachedNegOutput = nil
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ package imagegen
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -39,75 +38,17 @@ func DefaultOptions() ImageGenOptions {
|
||||
return ImageGenOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Steps: 0, // 0 means model default
|
||||
Seed: 0, // 0 means random
|
||||
}
|
||||
}
|
||||
|
||||
// Show displays information about an image generation model.
|
||||
func Show(modelName string, w io.Writer) error {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Count total size
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
|
||||
// Read model_index.json for architecture
|
||||
var architecture string
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
architecture = index.Architecture
|
||||
}
|
||||
}
|
||||
|
||||
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
|
||||
paramCount := totalSize / 2
|
||||
paramStr := formatParamCount(paramCount)
|
||||
|
||||
// Print Model info
|
||||
fmt.Fprintln(w, " Model")
|
||||
if architecture != "" {
|
||||
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
|
||||
}
|
||||
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
|
||||
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
|
||||
fmt.Fprintln(w)
|
||||
|
||||
// Print Capabilities
|
||||
fmt.Fprintln(w, " Capabilities")
|
||||
fmt.Fprintf(w, " %s\n", "image")
|
||||
fmt.Fprintln(w)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatParamCount formats parameter count as human-readable string.
|
||||
func formatParamCount(count int64) string {
|
||||
if count >= 1_000_000_000 {
|
||||
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
|
||||
}
|
||||
if count >= 1_000_000 {
|
||||
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
|
||||
}
|
||||
return fmt.Sprintf("%d", count)
|
||||
}
|
||||
|
||||
// RegisterFlags adds image generation flags to the given command.
|
||||
// Flags are hidden since they only apply to image generation models.
|
||||
func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("width", 1024, "Image width")
|
||||
cmd.Flags().Int("height", 1024, "Image height")
|
||||
cmd.Flags().Int("steps", 9, "Denoising steps")
|
||||
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
cmd.Flags().MarkHidden("width")
|
||||
@@ -121,11 +62,6 @@ func RegisterFlags(cmd *cobra.Command) {
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||
// Verify it's a valid image gen model
|
||||
if ResolveModelName(name) == "" {
|
||||
return fmt.Errorf("unknown image generation model: %s", name)
|
||||
}
|
||||
|
||||
// Get options from flags (with env var defaults)
|
||||
opts := DefaultOptions()
|
||||
if cmd != nil && cmd.Flags() != nil {
|
||||
@@ -161,17 +97,12 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
// Build request with image gen options encoded in Options fields
|
||||
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
|
||||
req := &api.GenerateRequest{
|
||||
req := &api.XGenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -183,31 +114,21 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imagePath string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
var imageBase64 string
|
||||
err = client.XGenerate(cmd.Context(), req, func(resp api.XGenerateResponse) error {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 && resp.Completed > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with image path
|
||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
imagePath = strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -218,9 +139,27 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
if imagePath != "" {
|
||||
displayImageInTerminal(imagePath)
|
||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
||||
if imageBase64 != "" {
|
||||
// Decode base64 and save to CWD
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
// Create filename from prompt
|
||||
safeName := sanitizeFilename(prompt)
|
||||
if len(safeName) > 50 {
|
||||
safeName = safeName[:50]
|
||||
}
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
|
||||
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to save image: %w", err)
|
||||
}
|
||||
|
||||
displayImageInTerminal(filename)
|
||||
fmt.Printf("Image saved to: %s\n", filename)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -286,15 +225,12 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
}
|
||||
|
||||
// Generate image with current options
|
||||
req := &api.GenerateRequest{
|
||||
req := &api.XGenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -306,31 +242,22 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
p.Add("", spinner)
|
||||
|
||||
var stepBar *progress.StepBar
|
||||
var imagePath string
|
||||
var imageBase64 string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
err = client.XGenerate(cmd.Context(), req, func(resp api.XGenerateResponse) error {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 && resp.Completed > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with image path
|
||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
||||
imagePath = strings.TrimSpace(content[idx+16:])
|
||||
}
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -342,25 +269,30 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
continue
|
||||
}
|
||||
|
||||
// Copy image to current directory with descriptive name
|
||||
if imagePath != "" {
|
||||
// Save image to current directory with descriptive name
|
||||
if imageBase64 != "" {
|
||||
// Decode base64 image data
|
||||
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create filename from prompt (sanitized)
|
||||
safeName := sanitizeFilename(line)
|
||||
if len(safeName) > 50 {
|
||||
safeName = safeName[:50]
|
||||
}
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||
|
||||
// Copy file to CWD
|
||||
if err := copyFile(imagePath, newName); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
|
||||
displayImageInTerminal(imagePath)
|
||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
||||
} else {
|
||||
displayImageInTerminal(newName)
|
||||
fmt.Printf("Image saved to: %s\n", newName)
|
||||
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
displayImageInTerminal(filename)
|
||||
fmt.Printf("Image saved to: %s\n", filename)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
@@ -381,24 +313,6 @@ func sanitizeFilename(s string) string {
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// copyFile copies a file from src to dst.
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, sourceFile)
|
||||
return err
|
||||
}
|
||||
|
||||
// printInteractiveHelp prints help for interactive mode commands.
|
||||
func printInteractiveHelp(opts ImageGenOptions) {
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
@@ -509,10 +423,7 @@ func displayImageInTerminal(imagePath string) bool {
|
||||
// Send in chunks for large images
|
||||
const chunkSize = 4096
|
||||
for i := 0; i < len(encoded); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(encoded) {
|
||||
end = len(encoded)
|
||||
}
|
||||
end := min(i+chunkSize, len(encoded))
|
||||
chunk := encoded[i:end]
|
||||
|
||||
if i == 0 {
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
// Package client provides client-side model creation for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
//
|
||||
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
|
||||
// 1. Add proper API endpoints for tensor model creation
|
||||
// 2. Move tensor extraction to server-side
|
||||
// 3. Remove this package
|
||||
// 4. Follow the same client→server pattern as regular model creation
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for image generation models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
//
|
||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
||||
if !imagegen.IsTensorModelDir(modelDir) {
|
||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||
}
|
||||
|
||||
status := "importing image generation model"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
|
||||
// Create layer callback for config files
|
||||
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create tensor layer callback for individual tensors
|
||||
// name is path-style: "component/tensor_name"
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create manifest writer callback
|
||||
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create a proper config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"image"},
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
status = msg
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
}
|
||||
|
||||
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created image generation model '%s'\n", modelName)
|
||||
return nil
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
@@ -48,7 +49,7 @@ func main() {
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 9, "Denoising steps")
|
||||
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
@@ -67,6 +68,9 @@ func main() {
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
||||
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
|
||||
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
|
||||
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@@ -99,13 +103,17 @@ func main() {
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
LayerCache: *layerCache,
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
TeaCache: *teaCache,
|
||||
TeaCacheThreshold: float32(*teaCacheThreshold),
|
||||
FusedQKV: *fusedQKV,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
|
||||
@@ -60,9 +60,12 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
||||
}
|
||||
|
||||
// Transform to [H, W, C] for image conversion
|
||||
img := mlx.Squeeze(arr, 0)
|
||||
img = mlx.Transpose(img, 1, 2, 0)
|
||||
img = mlx.Contiguous(img)
|
||||
// Free intermediate arrays to avoid memory leak
|
||||
squeezed := mlx.Squeeze(arr, 0)
|
||||
transposed := mlx.Transpose(squeezed, 1, 2, 0)
|
||||
squeezed.Free()
|
||||
img := mlx.Contiguous(transposed)
|
||||
transposed.Free()
|
||||
mlx.Eval(img)
|
||||
|
||||
imgShape := img.Shape()
|
||||
|
||||
@@ -175,3 +175,63 @@ func (m *ModelManifest) HasTensorLayers() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
ParameterCount int64
|
||||
Quantization string
|
||||
}
|
||||
|
||||
// GetModelInfo returns metadata about an image generation model.
|
||||
func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
info := &ModelInfo{}
|
||||
|
||||
// Read model_index.json for architecture, parameter count, and quantization
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ParameterCount int64 `json:"parameter_count"`
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
info.Architecture = index.Architecture
|
||||
info.ParameterCount = index.ParameterCount
|
||||
info.Quantization = index.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: detect quantization from tensor names if not in config
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
break
|
||||
}
|
||||
}
|
||||
if info.Quantization == "" {
|
||||
info.Quantization = "BF16"
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: estimate parameter count if not in config
|
||||
if info.ParameterCount == 0 {
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
}
|
||||
// Assume BF16 (2 bytes/param) as rough estimate
|
||||
info.ParameterCount = totalSize / 2
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
@@ -95,8 +95,3 @@ func EstimateVRAM(modelName string) uint64 {
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// HasTensorLayers checks if the given model has tensor layers.
|
||||
func HasTensorLayers(modelName string) bool {
|
||||
return ResolveModelName(modelName) != ""
|
||||
}
|
||||
|
||||
@@ -94,13 +94,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasTensorLayers(t *testing.T) {
|
||||
// Non-existent model should return false
|
||||
if HasTensorLayers("nonexistent-model") {
|
||||
t.Error("HasTensorLayers() should return false for non-existent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
|
||||
@@ -607,6 +607,11 @@ func (a *Array) Valid() bool {
|
||||
return a != nil && a.c.ctx != nil
|
||||
}
|
||||
|
||||
// Kept returns true if the array is marked to survive Eval() cleanup.
|
||||
func (a *Array) Kept() bool {
|
||||
return a != nil && a.kept
|
||||
}
|
||||
|
||||
func int32ToCInt(s []int32) *C.int {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
@@ -1480,6 +1485,44 @@ func (a *Array) ItemInt32() int32 {
|
||||
return int32(val)
|
||||
}
|
||||
|
||||
// Bytes copies the raw bytes out of the array without type conversion.
|
||||
// Works with common dtypes (float32, int32, uint32, uint8).
|
||||
// For non-contiguous arrays, call Contiguous() first.
|
||||
// Note: Triggers cleanup of non-kept arrays.
|
||||
func (a *Array) Bytes() []byte {
|
||||
cleanup()
|
||||
nbytes := a.Nbytes()
|
||||
if nbytes == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get raw pointer based on dtype
|
||||
var ptr unsafe.Pointer
|
||||
switch a.Dtype() {
|
||||
case DtypeFloat32:
|
||||
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
|
||||
case DtypeInt32:
|
||||
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
|
||||
case DtypeUint32:
|
||||
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
|
||||
case DtypeUint8:
|
||||
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
|
||||
default:
|
||||
// For other types (bf16, f16, etc), convert to float32
|
||||
arr := AsType(a, DtypeFloat32)
|
||||
arr.Eval()
|
||||
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
|
||||
nbytes = arr.Nbytes()
|
||||
}
|
||||
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
data := make([]byte, nbytes)
|
||||
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
|
||||
return data
|
||||
}
|
||||
|
||||
// ============ Utility ============
|
||||
|
||||
// String returns a string representation
|
||||
@@ -1658,6 +1701,34 @@ func (s *SafetensorsFile) Free() {
|
||||
C.mlx_map_string_to_string_free(s.metadata)
|
||||
}
|
||||
|
||||
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
|
||||
// This correctly handles all dtypes including uint32 for quantized weights.
|
||||
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
// Create the map
|
||||
cArrays := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(cArrays)
|
||||
|
||||
// Add each array to the map
|
||||
for name, arr := range arrays {
|
||||
cName := C.CString(name)
|
||||
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
|
||||
C.free(unsafe.Pointer(cName))
|
||||
}
|
||||
|
||||
// Create empty metadata (optional)
|
||||
cMeta := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(cMeta)
|
||||
|
||||
// Save
|
||||
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
|
||||
return fmt.Errorf("failed to save safetensors: %s", path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============ NPY Loading ============
|
||||
|
||||
// LoadNpy loads a numpy array from an npy file
|
||||
@@ -1986,7 +2057,8 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
|
||||
// Returns (quantized_weights, scales, biases).
|
||||
// groupSize: number of elements quantized together (default 64)
|
||||
// bits: bits per element, 2, 4, or 8 (default 4)
|
||||
// mode: "affine" (default) or "mxfp4"
|
||||
// mode: "affine" (default), "mxfp4", or "mxfp8"
|
||||
// Note: mxfp8 mode returns nil biases (only weights and scales)
|
||||
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
@@ -1995,14 +2067,21 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
||||
res := C.mlx_vector_array_new()
|
||||
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
||||
|
||||
// Result is a vector of 3 arrays: [weights, scales, biases]
|
||||
// Result is a vector of arrays: [weights, scales, biases?]
|
||||
// mxfp8 mode returns only 2 elements (no biases)
|
||||
vecSize := int(C.mlx_vector_array_size(res))
|
||||
var w0, w1, w2 C.mlx_array
|
||||
C.mlx_vector_array_get(&w0, res, 0)
|
||||
C.mlx_vector_array_get(&w1, res, 1)
|
||||
C.mlx_vector_array_get(&w2, res, 2)
|
||||
if vecSize >= 3 {
|
||||
C.mlx_vector_array_get(&w2, res, 2)
|
||||
}
|
||||
C.mlx_vector_array_free(res)
|
||||
|
||||
return newArray(w0), newArray(w1), newArray(w2)
|
||||
if vecSize >= 3 {
|
||||
return newArray(w0), newArray(w1), newArray(w2)
|
||||
}
|
||||
return newArray(w0), newArray(w1), nil
|
||||
}
|
||||
|
||||
// Dequantize reconstructs weights from quantized form.
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -172,7 +173,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 30
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
@@ -222,6 +223,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
@@ -264,10 +273,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// True CFG: run twice and combine with norm rescaling
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
L := batchedOutput.Shape()[1]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
@@ -305,6 +323,9 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
|
||||
@@ -241,6 +241,14 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// Encode all input images to latents and concatenate
|
||||
fmt.Println("Encoding images to latents...")
|
||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||
@@ -291,11 +299,18 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Tile inputs: [1, L, D] -> [2, L, D]
|
||||
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
|
||||
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
|
||||
// Single batched forward pass
|
||||
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
D := batchedOutput.Shape()[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
|
||||
|
||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||
} else {
|
||||
@@ -317,6 +332,9 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
imageLatentsPacked.Free()
|
||||
|
||||
@@ -28,12 +28,12 @@ type Qwen3Config struct {
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OProj *nn.Linear `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
QProj nn.LinearLayer `weight:"q_proj"`
|
||||
KProj nn.LinearLayer `weight:"k_proj"`
|
||||
VProj nn.LinearLayer `weight:"v_proj"`
|
||||
OProj nn.LinearLayer `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
|
||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||
type Qwen3MLP struct {
|
||||
GateProj *nn.Linear `weight:"gate_proj"`
|
||||
UpProj *nn.Linear `weight:"up_proj"`
|
||||
DownProj *nn.Linear `weight:"down_proj"`
|
||||
GateProj nn.LinearLayer `weight:"gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
|
||||
@@ -36,8 +36,8 @@ type TransformerConfig struct {
|
||||
// TimestepEmbedder creates sinusoidal timestep embeddings
|
||||
// Output dimension is 256 (fixed), used for AdaLN modulation
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 *nn.Linear `weight:"mlp.0"`
|
||||
Linear2 *nn.Linear `weight:"mlp.2"`
|
||||
Linear1 nn.LinearLayer `weight:"mlp.0"`
|
||||
Linear2 nn.LinearLayer `weight:"mlp.2"`
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
|
||||
// XEmbedder embeds image patches to model dimension
|
||||
type XEmbedder struct {
|
||||
Linear *nn.Linear `weight:"2-1"`
|
||||
Linear nn.LinearLayer `weight:"2-1"`
|
||||
}
|
||||
|
||||
// Forward embeds patchified image latents
|
||||
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
||||
// CapEmbedder projects caption features to model dimension
|
||||
type CapEmbedder struct {
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear *nn.Linear `weight:"1"`
|
||||
Linear nn.LinearLayer `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
}
|
||||
|
||||
@@ -100,12 +100,13 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
|
||||
|
||||
// FeedForward implements SwiGLU FFN
|
||||
type FeedForward struct {
|
||||
W1 *nn.Linear `weight:"w1"` // gate projection
|
||||
W2 *nn.Linear `weight:"w2"` // down projection
|
||||
W3 *nn.Linear `weight:"w3"` // up projection
|
||||
W1 nn.LinearLayer `weight:"w1"` // gate projection
|
||||
W2 nn.LinearLayer `weight:"w2"` // down projection
|
||||
W3 nn.LinearLayer `weight:"w3"` // up projection
|
||||
OutDim int32 // computed from W2
|
||||
}
|
||||
|
||||
|
||||
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
@@ -115,6 +116,7 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Reshape for matmul
|
||||
x = mlx.Reshape(x, B*L, D)
|
||||
|
||||
gate := ff.W1.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := ff.W3.Forward(x)
|
||||
@@ -126,17 +128,69 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
|
||||
// Attention implements multi-head attention with QK norm
|
||||
type Attention struct {
|
||||
ToQ *nn.Linear `weight:"to_q"`
|
||||
ToK *nn.Linear `weight:"to_k"`
|
||||
ToV *nn.Linear `weight:"to_v"`
|
||||
ToOut *nn.Linear `weight:"to_out.0"`
|
||||
ToQ nn.LinearLayer `weight:"to_q"`
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
ToOut nn.LinearLayer `weight:"to_out.0"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Dim int32
|
||||
Scale float32
|
||||
// Fused QKV (computed at init time for efficiency, not loaded from weights)
|
||||
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
|
||||
Fused bool `weight:"-"` // Whether to use fused QKV path
|
||||
// Computed fields (not loaded from weights)
|
||||
NHeads int32 `weight:"-"`
|
||||
HeadDim int32 `weight:"-"`
|
||||
Dim int32 `weight:"-"`
|
||||
Scale float32 `weight:"-"`
|
||||
}
|
||||
|
||||
// FuseQKV creates a fused QKV projection by concatenating weights.
|
||||
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
|
||||
// Note: Fusion is skipped for quantized weights as it would require complex
|
||||
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
|
||||
// the ~5% fusion benefit.
|
||||
func (attn *Attention) FuseQKV() {
|
||||
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip fusion for quantized weights - type assert to check
|
||||
toQ, qOk := attn.ToQ.(*nn.Linear)
|
||||
toK, kOk := attn.ToK.(*nn.Linear)
|
||||
toV, vOk := attn.ToV.(*nn.Linear)
|
||||
if !qOk || !kOk || !vOk {
|
||||
// One or more are QuantizedLinear, skip fusion
|
||||
return
|
||||
}
|
||||
|
||||
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
|
||||
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
|
||||
qWeight := toQ.Weight
|
||||
kWeight := toK.Weight
|
||||
vWeight := toV.Weight
|
||||
|
||||
// Concatenate along output dimension (axis 0)
|
||||
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
|
||||
|
||||
// Evaluate fused weight to ensure it's materialized
|
||||
mlx.Eval(fusedWeight)
|
||||
|
||||
// Create fused linear layer
|
||||
fusedLinear := &nn.Linear{Weight: fusedWeight}
|
||||
|
||||
// Handle bias if present
|
||||
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
|
||||
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
|
||||
mlx.Eval(fusedBias)
|
||||
fusedLinear.Bias = fusedBias
|
||||
}
|
||||
|
||||
attn.ToQKV = fusedLinear
|
||||
attn.Fused = true
|
||||
}
|
||||
|
||||
// Forward computes attention
|
||||
@@ -146,11 +200,24 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Project Q, K, V
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
q := attn.ToQ.Forward(xFlat)
|
||||
k := attn.ToK.Forward(xFlat)
|
||||
v := attn.ToV.Forward(xFlat)
|
||||
|
||||
var q, k, v *mlx.Array
|
||||
if attn.Fused && attn.ToQKV != nil {
|
||||
// Fused QKV path: single matmul then split
|
||||
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
|
||||
|
||||
// Split into Q, K, V along last dimension
|
||||
// Each has shape [B*L, dim]
|
||||
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
|
||||
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
|
||||
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
|
||||
} else {
|
||||
// Separate Q, K, V projections
|
||||
q = attn.ToQ.Forward(xFlat)
|
||||
k = attn.ToK.Forward(xFlat)
|
||||
v = attn.ToV.Forward(xFlat)
|
||||
}
|
||||
|
||||
// Reshape to [B, L, nheads, head_dim]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
@@ -227,7 +294,7 @@ type TransformerBlock struct {
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
// Computed fields
|
||||
HasModulation bool
|
||||
Dim int32
|
||||
@@ -281,8 +348,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
|
||||
|
||||
// FinalLayer outputs the denoised patches
|
||||
type FinalLayer struct {
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
|
||||
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
|
||||
OutDim int32 // computed from Output
|
||||
}
|
||||
|
||||
@@ -350,12 +417,11 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
|
||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
|
||||
// Load weights from tensor blobs with BF16 conversion
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
@@ -377,7 +443,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
|
||||
func (m *Transformer) initComputedFields() {
|
||||
cfg := m.TransformerConfig
|
||||
m.TEmbed.FreqEmbedSize = 256
|
||||
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
|
||||
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
|
||||
m.CapEmbed.Norm.Eps = 1e-6
|
||||
|
||||
for _, block := range m.NoiseRefiners {
|
||||
@@ -391,6 +457,20 @@ func (m *Transformer) initComputedFields() {
|
||||
}
|
||||
}
|
||||
|
||||
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
|
||||
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
|
||||
func (m *Transformer) FuseAllQKV() {
|
||||
for _, block := range m.NoiseRefiners {
|
||||
block.Attention.FuseQKV()
|
||||
}
|
||||
for _, block := range m.ContextRefiners {
|
||||
block.Attention.FuseQKV()
|
||||
}
|
||||
for _, block := range m.Layers {
|
||||
block.Attention.FuseQKV()
|
||||
}
|
||||
}
|
||||
|
||||
// initTransformerBlock sets computed fields on a transformer block
|
||||
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
||||
block.Dim = cfg.Dim
|
||||
@@ -404,7 +484,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
||||
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
|
||||
|
||||
// Init feedforward OutDim
|
||||
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
|
||||
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
|
||||
|
||||
// Set eps on all RMSNorm layers
|
||||
block.AttentionNorm1.Eps = cfg.NormEps
|
||||
@@ -423,6 +503,8 @@ type RoPECache struct {
|
||||
UnifiedSin *mlx.Array
|
||||
ImgLen int32
|
||||
CapLen int32
|
||||
GridH int32 // Image token grid height
|
||||
GridW int32 // Image token grid width
|
||||
}
|
||||
|
||||
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
|
||||
@@ -456,6 +538,8 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
|
||||
UnifiedSin: unifiedSin,
|
||||
ImgLen: imgLen,
|
||||
CapLen: capLen,
|
||||
GridH: hTok,
|
||||
GridW: wTok,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -104,6 +104,8 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
|
||||
groupSize := C / gn.NumGroups
|
||||
|
||||
// Keep the input - we need it for slicing tiles later
|
||||
// Track if we were the ones who kept it, so we can restore state after
|
||||
wasKept := x.Kept()
|
||||
mlx.Keep(x)
|
||||
|
||||
// Compute per-group mean and variance using flattened spatial dimensions
|
||||
@@ -205,6 +207,10 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
|
||||
}
|
||||
|
||||
// Clean up kept arrays
|
||||
// Restore x's kept state - only free if we were the ones who kept it
|
||||
if !wasKept {
|
||||
x.Free()
|
||||
}
|
||||
mean.Free()
|
||||
invStd.Free()
|
||||
if weightGN != nil {
|
||||
@@ -734,18 +740,26 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
h := vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
|
||||
prev := h
|
||||
h = vae.MidBlock.Forward(h)
|
||||
prev.Free()
|
||||
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
prev = h
|
||||
h = upBlock.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
prev := h
|
||||
prev = h
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
|
||||
prev.Free()
|
||||
|
||||
prev = h
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
mlx.Eval(h)
|
||||
prev.Free()
|
||||
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.MulScalar(h, 0.5)
|
||||
@@ -754,7 +768,6 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
|
||||
return h
|
||||
|
||||
@@ -26,10 +26,12 @@ type GenerateConfig struct {
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// Layer caching options (speedup via shallow layer reuse)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 15)
|
||||
// TeaCache options (timestep embedding aware caching)
|
||||
TeaCache bool // TeaCache is always enabled for faster inference
|
||||
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
|
||||
|
||||
// Fused QKV (fuse Q/K/V projections into single matmul)
|
||||
FusedQKV bool // Enable fused QKV projection (default: false)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
@@ -42,6 +44,7 @@ type Model struct {
|
||||
TextEncoder *Qwen3TextEncoder
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
qkvFused bool // Track if QKV has been fused (do only once)
|
||||
}
|
||||
|
||||
// Load loads the Z-Image model from ollama blob storage.
|
||||
@@ -191,18 +194,22 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 9 // Turbo default
|
||||
cfg.Steps = 9 // Z-Image turbo default
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.LayerCache {
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 15 // Half of 30 layers
|
||||
}
|
||||
// TeaCache enabled by default
|
||||
cfg.TeaCache = true
|
||||
if cfg.TeaCacheThreshold <= 0 {
|
||||
cfg.TeaCacheThreshold = 0.15
|
||||
}
|
||||
|
||||
// Enable fused QKV if requested (only fuse once)
|
||||
if cfg.FusedQKV && !m.qkvFused {
|
||||
m.Transformer.FuseAllQKV()
|
||||
m.qkvFused = true
|
||||
fmt.Println(" Fused QKV enabled")
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
@@ -260,12 +267,54 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
mlx.Eval(ropeCache.UnifiedCos)
|
||||
}
|
||||
|
||||
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
|
||||
cfg.CacheLayers, cfg.CacheInterval)
|
||||
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
|
||||
var batchedEmb *mlx.Array
|
||||
if useCFG {
|
||||
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
|
||||
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||
mlx.Keep(batchedEmb)
|
||||
mlx.Eval(batchedEmb)
|
||||
}
|
||||
|
||||
// TeaCache for timestep-aware caching
|
||||
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
|
||||
var teaCache *cache.TeaCache
|
||||
if cfg.TeaCache {
|
||||
skipEarly := 0
|
||||
if useCFG {
|
||||
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
|
||||
}
|
||||
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
|
||||
Threshold: cfg.TeaCacheThreshold,
|
||||
RescaleFactor: 1.0,
|
||||
SkipEarlySteps: skipEarly,
|
||||
})
|
||||
if useCFG {
|
||||
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
|
||||
} else {
|
||||
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup frees all kept arrays when we need to abort early
|
||||
cleanup := func() {
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgCos.Free()
|
||||
ropeCache.ImgSin.Free()
|
||||
ropeCache.CapCos.Free()
|
||||
ropeCache.CapSin.Free()
|
||||
ropeCache.UnifiedCos.Free()
|
||||
ropeCache.UnifiedSin.Free()
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
if teaCache != nil {
|
||||
teaCache.Free()
|
||||
}
|
||||
latents.Free()
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
@@ -277,6 +326,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cleanup()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
@@ -289,50 +339,77 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
}
|
||||
|
||||
tCurr := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
|
||||
var noisePred *mlx.Array
|
||||
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
// TeaCache: check if we should compute or reuse cached output
|
||||
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
|
||||
|
||||
var output *mlx.Array
|
||||
if stepCache != nil {
|
||||
// Use layer caching for faster inference
|
||||
if shouldCompute {
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
// Note: CFG with layer cache shares the cache between pos/neg
|
||||
// This is approximate but fast - neg prompt uses same cached shallow layers
|
||||
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
// CFG Batching: single forward pass with batch=2
|
||||
// Tile patches: [1, L, D] -> [2, L, D]
|
||||
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||
// Tile timestep: [1] -> [2]
|
||||
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||
|
||||
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
|
||||
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
|
||||
|
||||
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||
outputShape := batchedOutput.Shape()
|
||||
L := outputShape[1]
|
||||
D := outputShape[2]
|
||||
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||
|
||||
// Convert to noise predictions (unpatchify and negate)
|
||||
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
posPred = mlx.Neg(posPred)
|
||||
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
negPred = mlx.Neg(negPred)
|
||||
|
||||
// Cache pos/neg separately for TeaCache
|
||||
if teaCache != nil {
|
||||
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
|
||||
mlx.Keep(teaCache.Arrays()...)
|
||||
}
|
||||
|
||||
// Apply CFG: noisePred = neg + scale * (pos - neg)
|
||||
diff := mlx.Sub(posPred, negPred)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
} else {
|
||||
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
}
|
||||
} else {
|
||||
// Standard forward without caching
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
noisePred = mlx.Add(negPred, scaledDiff)
|
||||
} else {
|
||||
// Non-CFG forward pass
|
||||
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
}
|
||||
}
|
||||
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
noisePred = mlx.Neg(noisePred)
|
||||
|
||||
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
noisePred = mlx.Neg(noisePred)
|
||||
// Update TeaCache
|
||||
if teaCache != nil {
|
||||
teaCache.UpdateCache(noisePred, tCurr)
|
||||
mlx.Keep(teaCache.Arrays()...)
|
||||
}
|
||||
}
|
||||
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
|
||||
// CFG mode: get cached pos/neg and compute CFG fresh
|
||||
posPred, negPred := teaCache.GetCFGCached()
|
||||
diff := mlx.Sub(posPred, negPred)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
noisePred = mlx.Add(negPred, scaledDiff)
|
||||
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
|
||||
} else {
|
||||
// Non-CFG mode: reuse cached noise prediction
|
||||
noisePred = teaCache.GetCached()
|
||||
fmt.Printf(" [TeaCache: reusing cached output]\n")
|
||||
}
|
||||
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep latents and any cached arrays
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
@@ -361,8 +438,14 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
ropeCache.CapSin.Free()
|
||||
ropeCache.UnifiedCos.Free()
|
||||
ropeCache.UnifiedSin.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
if batchedEmb != nil {
|
||||
batchedEmb.Free()
|
||||
}
|
||||
if teaCache != nil {
|
||||
hits, misses := teaCache.Stats()
|
||||
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
|
||||
hits, misses, float64(hits)/float64(hits+misses)*100)
|
||||
teaCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode
|
||||
|
||||
@@ -10,6 +10,13 @@ type Layer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// LinearLayer is an interface for linear layers (both regular and quantized).
|
||||
// This allows swapping between Linear and QuantizedLinear at runtime.
|
||||
type LinearLayer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
OutputDim() int32 // Returns the output dimension of the layer
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
|
||||
type Linear struct {
|
||||
@@ -49,6 +56,11 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Linear(x, w)
|
||||
}
|
||||
|
||||
// OutputDim returns the output dimension of the linear layer.
|
||||
func (l *Linear) OutputDim() int32 {
|
||||
return l.Weight.Shape()[0]
|
||||
}
|
||||
|
||||
// ToQuantized converts this Linear to a QuantizedLinear.
|
||||
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
|
||||
@@ -84,6 +96,13 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
return out
|
||||
}
|
||||
|
||||
// OutputDim returns the output dimension of the quantized linear layer.
|
||||
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
|
||||
// The output dimension is the first dimension of the weight.
|
||||
func (ql *QuantizedLinear) OutputDim() int32 {
|
||||
return ql.Weight.Shape()[0]
|
||||
}
|
||||
|
||||
// RMSNorm represents an RMS normalization layer.
|
||||
type RMSNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -34,7 +33,8 @@ type Request struct {
|
||||
|
||||
// Response is streamed back for each progress update
|
||||
type Response struct {
|
||||
Content string `json:"content"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Image string `json:"image,omitempty"` // Base64-encoded PNG
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
@@ -136,16 +136,8 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Apply defaults
|
||||
if req.Width <= 0 {
|
||||
req.Width = 1024
|
||||
}
|
||||
if req.Height <= 0 {
|
||||
req.Height = 1024
|
||||
}
|
||||
if req.Steps <= 0 {
|
||||
req.Steps = 9
|
||||
}
|
||||
// Model applies its own defaults for width/height/steps
|
||||
// Only seed needs to be set here if not provided
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
}
|
||||
@@ -191,10 +183,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Save image
|
||||
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
|
||||
if err := imagegen.SaveImage(img, outPath); err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
|
||||
// Encode image as base64 PNG
|
||||
imageData, err := imagegen.EncodeImageBase64(img)
|
||||
if err != nil {
|
||||
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
@@ -204,11 +196,12 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Free the generated image array and clean up MLX state
|
||||
img.Free()
|
||||
mlx.ClearCache()
|
||||
mlx.MetalResetPeakMemory()
|
||||
|
||||
// Send final response
|
||||
// Send final response with image data
|
||||
resp := Response{
|
||||
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
|
||||
Done: true,
|
||||
Image: imageData,
|
||||
Done: true,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// WeightSource is an interface for loading weights.
|
||||
@@ -102,6 +103,22 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
|
||||
}
|
||||
}
|
||||
|
||||
// Handle nn.LinearLayer interface fields specially
|
||||
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
|
||||
if !hasTag {
|
||||
continue // no tag = skip
|
||||
}
|
||||
layer, err := LoadLinearLayer(weights, fullPath)
|
||||
if err != nil {
|
||||
if !optional {
|
||||
*errs = append(*errs, fullPath+": "+err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
fieldVal.Set(reflect.ValueOf(layer))
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle by kind
|
||||
switch fieldVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
@@ -176,3 +193,64 @@ func joinPath(prefix, suffix string) string {
|
||||
}
|
||||
return prefix + "." + suffix
|
||||
}
|
||||
|
||||
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
|
||||
// If {path}.weight_scale exists, dequantizes the weights.
|
||||
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
|
||||
// Check if this is a quantized layer by looking for scale tensor
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
weight, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
|
||||
}
|
||||
|
||||
scales, err := weights.GetTensor(scalePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
|
||||
}
|
||||
|
||||
// Bias is optional
|
||||
var bias *mlx.Array
|
||||
biasPath := path + ".bias"
|
||||
if weights.HasTensor(biasPath) {
|
||||
bias, _ = weights.GetTensor(biasPath)
|
||||
}
|
||||
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
|
||||
if mlx.MetalIsAvailable() {
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: weight,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: 32,
|
||||
Bits: 8,
|
||||
Mode: "affine",
|
||||
}, nil
|
||||
}
|
||||
|
||||
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
|
||||
return nn.NewLinear(dequantized, bias), nil
|
||||
}
|
||||
|
||||
// Load as regular Linear
|
||||
weight, err := weights.GetTensor(path + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Bias is optional
|
||||
var bias *mlx.Array
|
||||
biasPath := path + ".bias"
|
||||
if weights.HasTensor(biasPath) {
|
||||
bias, _ = weights.GetTensor(biasPath)
|
||||
}
|
||||
|
||||
return nn.NewLinear(weight, bias), nil
|
||||
}
|
||||
|
||||
@@ -14,7 +14,9 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -23,6 +25,11 @@ import (
|
||||
)
|
||||
|
||||
// Server wraps an image generation subprocess to implement llm.LlamaServer.
|
||||
//
|
||||
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
|
||||
// like any other model. The plan is to eventually bring this into the llm/ package
|
||||
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
|
||||
// separate allows for independent iteration on image generation support.
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
@@ -35,21 +42,6 @@ type Server struct {
|
||||
lastErrLock sync.Mutex
|
||||
}
|
||||
|
||||
// completionRequest is sent to the subprocess
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// completionResponse is received from the subprocess
|
||||
type completionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// NewServer spawns a new image generation subprocess and waits until it's ready.
|
||||
func NewServer(modelName string) (*Server, error) {
|
||||
// Validate platform support before attempting to start
|
||||
@@ -69,7 +61,7 @@ func NewServer(modelName string) (*Server, error) {
|
||||
port = rand.Intn(65535-49152) + 49152
|
||||
}
|
||||
|
||||
// Get the ollama executable path
|
||||
// Get the ollama-mlx executable path (in same directory as current executable)
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
@@ -77,11 +69,42 @@ func NewServer(modelName string) (*Server, error) {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
|
||||
|
||||
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
if runtime.GOOS == "linux" {
|
||||
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
|
||||
libraryPaths := []string{ml.LibOllamaPath}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
||||
}
|
||||
|
||||
// Append existing LD_LIBRARY_PATH if set
|
||||
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
// Update or add LD_LIBRARY_PATH in cmd.Env
|
||||
found := false
|
||||
for i := range cmd.Env {
|
||||
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
||||
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
||||
}
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cmd: cmd,
|
||||
port: port,
|
||||
@@ -105,14 +128,13 @@ func NewServer(modelName string) (*Server, error) {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
slog.Warn("image-runner", "msg", line)
|
||||
// Capture last error line for better error reporting
|
||||
s.lastErrLock.Lock()
|
||||
s.lastErr = line
|
||||
s.lastErrLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting image runner subprocess", "model", modelName, "port", port)
|
||||
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start image runner: %w", err)
|
||||
}
|
||||
@@ -137,7 +159,6 @@ func (s *Server) ModelPath() string {
|
||||
return s.modelName
|
||||
}
|
||||
|
||||
// Load is called by the scheduler after the server is created.
|
||||
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -170,20 +191,16 @@ func (s *Server) waitUntilRunning() error {
|
||||
for {
|
||||
select {
|
||||
case err := <-s.done:
|
||||
// Include last stderr line for better error context
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
|
||||
// Include recent stderr lines for better error context
|
||||
errMsg := s.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
|
||||
}
|
||||
return fmt.Errorf("image runner exited unexpectedly: %w", err)
|
||||
case <-timeout:
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
|
||||
errMsg := s.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
|
||||
}
|
||||
return errors.New("timeout waiting for image runner to start")
|
||||
case <-ticker.C:
|
||||
@@ -195,44 +212,41 @@ func (s *Server) waitUntilRunning() error {
|
||||
}
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error {
|
||||
return nil
|
||||
// getLastErr returns the last stderr line.
|
||||
func (s *Server) getLastErr() string {
|
||||
s.lastErrLock.Lock()
|
||||
defer s.lastErrLock.Unlock()
|
||||
return s.lastErr
|
||||
}
|
||||
|
||||
// Completion generates an image from the prompt via the subprocess.
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
|
||||
|
||||
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
// Build request
|
||||
creq := completionRequest{
|
||||
seed := req.Seed
|
||||
if seed == 0 {
|
||||
seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Build request for subprocess
|
||||
creq := struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}{
|
||||
Prompt: req.Prompt,
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Seed: time.Now().UnixNano(),
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
}
|
||||
|
||||
if req.Options != nil {
|
||||
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
|
||||
creq.Width = int32(req.Options.NumCtx)
|
||||
}
|
||||
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
|
||||
creq.Height = int32(req.Options.NumGPU)
|
||||
}
|
||||
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
|
||||
creq.Steps = req.Options.NumPredict
|
||||
}
|
||||
if req.Options.Seed > 0 {
|
||||
creq.Seed = int64(req.Options.Seed)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode request body
|
||||
body, err := json.Marshal(creq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send request to subprocess
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
@@ -247,22 +261,36 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
|
||||
return fmt.Errorf("request failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Stream responses
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
|
||||
for scanner.Scan() {
|
||||
var cresp completionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
|
||||
// Parse subprocess response (has singular "image" field)
|
||||
var raw struct {
|
||||
Image string `json:"image,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
continue
|
||||
}
|
||||
fn(llm.CompletionResponse{
|
||||
Content: cresp.Content,
|
||||
Done: cresp.Done,
|
||||
})
|
||||
|
||||
// Convert to llm.CompletionResponse
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
TotalSteps: raw.Total,
|
||||
Image: raw.Image,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
if cresp.Done {
|
||||
break
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,22 +332,18 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// Embedding is not supported for image generation models.
|
||||
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return nil, 0, errors.New("embedding not supported for image generation models")
|
||||
return nil, 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
// Tokenize is not supported for image generation models.
|
||||
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return nil, errors.New("tokenize not supported for image generation models")
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
// Detokenize is not supported for image generation models.
|
||||
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", errors.New("detokenize not supported for image generation models")
|
||||
return "", errors.New("not supported")
|
||||
}
|
||||
|
||||
// Pid returns the subprocess PID.
|
||||
func (s *Server) Pid() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -329,17 +353,9 @@ func (s *Server) Pid() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// GetPort returns the subprocess port.
|
||||
func (s *Server) GetPort() int {
|
||||
return s.port
|
||||
}
|
||||
func (s *Server) GetPort() int { return s.port }
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
|
||||
|
||||
// GetDeviceInfos returns nil since we don't track GPU info.
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasExited returns true if the subprocess has exited.
|
||||
func (s *Server) HasExited() bool {
|
||||
select {
|
||||
case <-s.done:
|
||||
|
||||
@@ -45,24 +45,33 @@ func download(ctx context.Context, opts DownloadOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter existing
|
||||
var blobs []Blob
|
||||
// Calculate total from all blobs (for accurate progress reporting on resume)
|
||||
var total int64
|
||||
for _, b := range opts.Blobs {
|
||||
total += b.Size
|
||||
}
|
||||
|
||||
// Filter out already-downloaded blobs and track completed bytes
|
||||
var blobs []Blob
|
||||
var alreadyCompleted int64
|
||||
for _, b := range opts.Blobs {
|
||||
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
|
||||
if opts.Logger != nil {
|
||||
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
|
||||
}
|
||||
alreadyCompleted += b.Size
|
||||
continue
|
||||
}
|
||||
blobs = append(blobs, b)
|
||||
total += b.Size
|
||||
}
|
||||
if len(blobs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
token := opts.Token
|
||||
progress := newProgressTracker(total, opts.Progress)
|
||||
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront
|
||||
|
||||
d := &downloader{
|
||||
client: cmp.Or(opts.Client, defaultClient),
|
||||
baseURL: opts.BaseURL,
|
||||
@@ -72,7 +81,7 @@ func download(ctx context.Context, opts DownloadOptions) error {
|
||||
getToken: opts.GetToken,
|
||||
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
||||
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
|
||||
progress: newProgressTracker(total, opts.Progress),
|
||||
progress: progress,
|
||||
speeds: &speedTracker{},
|
||||
logger: opts.Logger,
|
||||
}
|
||||
|
||||
@@ -110,8 +110,6 @@ var defaultClient = &http.Client{
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
Timeout: 5 * time.Minute,
|
||||
// Don't follow redirects automatically - we handle them manually
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
|
||||
@@ -284,6 +284,83 @@ func TestDownloadSkipsExisting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadResumeProgressTotal(t *testing.T) {
|
||||
// Test that when resuming a download with some blobs already present:
|
||||
// 1. Total reflects ALL blob sizes (not just remaining)
|
||||
// 2. Completed starts at the size of already-downloaded blobs
|
||||
serverDir := t.TempDir()
|
||||
blob1, data1 := createTestBlob(t, serverDir, 1000)
|
||||
blob2, data2 := createTestBlob(t, serverDir, 2000)
|
||||
blob3, data3 := createTestBlob(t, serverDir, 3000)
|
||||
|
||||
// Pre-populate client with blob1 and blob2 (simulating partial download)
|
||||
clientDir := t.TempDir()
|
||||
for _, b := range []struct {
|
||||
blob Blob
|
||||
data []byte
|
||||
}{{blob1, data1}, {blob2, data2}} {
|
||||
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(path, b.data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
digest := filepath.Base(r.URL.Path)
|
||||
path := filepath.Join(serverDir, digestToPath(digest))
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var firstCompleted, firstTotal int64
|
||||
var gotFirstProgress bool
|
||||
var mu sync.Mutex
|
||||
|
||||
err := Download(context.Background(), DownloadOptions{
|
||||
Blobs: []Blob{blob1, blob2, blob3},
|
||||
BaseURL: server.URL,
|
||||
DestDir: clientDir,
|
||||
Concurrency: 1,
|
||||
Progress: func(completed, total int64) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if !gotFirstProgress {
|
||||
firstCompleted = completed
|
||||
firstTotal = total
|
||||
gotFirstProgress = true
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
// Total should be sum of ALL blobs, not just blob3
|
||||
expectedTotal := blob1.Size + blob2.Size + blob3.Size
|
||||
if firstTotal != expectedTotal {
|
||||
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
|
||||
}
|
||||
|
||||
// First progress call should show already-completed bytes from blob1+blob2
|
||||
expectedCompleted := blob1.Size + blob2.Size
|
||||
if firstCompleted < expectedCompleted {
|
||||
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
|
||||
}
|
||||
|
||||
// Verify blob3 was downloaded
|
||||
verifyBlob(t, clientDir, blob3, data3)
|
||||
}
|
||||
|
||||
func TestDownloadDigestMismatch(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return wrong data
|
||||
|
||||
284
x/server/show.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// modelConfig represents the HuggingFace config.json structure
|
||||
type modelConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
HiddenSize int `json:"hidden_size"`
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings int `json:"max_position_embeddings"`
|
||||
IntermediateSize int `json:"intermediate_size"`
|
||||
NumAttentionHeads int `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int `json:"num_key_value_heads"`
|
||||
VocabSize int `json:"vocab_size"`
|
||||
RMSNormEps float64 `json:"rms_norm_eps"`
|
||||
RopeTheta float64 `json:"rope_theta"`
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
TextConfig *struct {
|
||||
HiddenSize int `json:"hidden_size"`
|
||||
MaxPositionEmbeddings int `json:"max_position_embeddings"`
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
|
||||
// It reads the config.json layer and returns a map compatible with GGML's KV format.
|
||||
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
var config modelConfig
|
||||
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
// Calculate total tensor bytes from manifest layers
|
||||
var totalBytes int64
|
||||
var tensorCount int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
totalBytes += layer.Size
|
||||
tensorCount++
|
||||
}
|
||||
}
|
||||
|
||||
return buildModelInfo(config, totalBytes, tensorCount), nil
|
||||
}
|
||||
|
||||
// buildModelInfo constructs the model info map from config and tensor stats.
|
||||
// This is separated for testability.
|
||||
func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map[string]any {
|
||||
// Determine architecture
|
||||
arch := config.ModelType
|
||||
if arch == "" && len(config.Architectures) > 0 {
|
||||
// Convert HuggingFace architecture name to Ollama format
|
||||
// e.g., "Gemma3ForCausalLM" -> "gemma3"
|
||||
hfArch := config.Architectures[0]
|
||||
arch = strings.ToLower(hfArch)
|
||||
arch = strings.TrimSuffix(arch, "forcausallm")
|
||||
arch = strings.TrimSuffix(arch, "forconditionalgeneration")
|
||||
}
|
||||
|
||||
// Use text_config values if they exist (for multimodal models)
|
||||
hiddenSize := config.HiddenSize
|
||||
maxPosEmbed := config.MaxPositionEmbeddings
|
||||
numLayers := config.NumHiddenLayers
|
||||
|
||||
if config.TextConfig != nil {
|
||||
if config.TextConfig.HiddenSize > 0 {
|
||||
hiddenSize = config.TextConfig.HiddenSize
|
||||
}
|
||||
if config.TextConfig.MaxPositionEmbeddings > 0 {
|
||||
maxPosEmbed = config.TextConfig.MaxPositionEmbeddings
|
||||
}
|
||||
if config.TextConfig.NumHiddenLayers > 0 {
|
||||
numLayers = config.TextConfig.NumHiddenLayers
|
||||
}
|
||||
}
|
||||
|
||||
// Get dtype to determine bytes per parameter for count calculation
|
||||
dtype := config.TorchDtype
|
||||
|
||||
// Determine bytes per parameter based on dtype
|
||||
var bytesPerParam int64 = 2 // default to float16/bfloat16
|
||||
switch strings.ToLower(dtype) {
|
||||
case "float32":
|
||||
bytesPerParam = 4
|
||||
case "float16", "bfloat16":
|
||||
bytesPerParam = 2
|
||||
case "int8", "uint8":
|
||||
bytesPerParam = 1
|
||||
}
|
||||
|
||||
// Subtract safetensors header overhead (88 bytes per tensor file)
|
||||
// Each tensor is stored as a minimal safetensors file
|
||||
totalBytes := totalTensorBytes - tensorCount*88
|
||||
|
||||
paramCount := totalBytes / bytesPerParam
|
||||
|
||||
info := map[string]any{
|
||||
"general.architecture": arch,
|
||||
}
|
||||
|
||||
if maxPosEmbed > 0 {
|
||||
info[fmt.Sprintf("%s.context_length", arch)] = maxPosEmbed
|
||||
}
|
||||
|
||||
if hiddenSize > 0 {
|
||||
info[fmt.Sprintf("%s.embedding_length", arch)] = hiddenSize
|
||||
}
|
||||
|
||||
if numLayers > 0 {
|
||||
info[fmt.Sprintf("%s.block_count", arch)] = numLayers
|
||||
}
|
||||
|
||||
if config.NumAttentionHeads > 0 {
|
||||
info[fmt.Sprintf("%s.attention.head_count", arch)] = config.NumAttentionHeads
|
||||
}
|
||||
|
||||
if config.NumKeyValueHeads > 0 {
|
||||
info[fmt.Sprintf("%s.attention.head_count_kv", arch)] = config.NumKeyValueHeads
|
||||
}
|
||||
|
||||
if config.IntermediateSize > 0 {
|
||||
info[fmt.Sprintf("%s.feed_forward_length", arch)] = config.IntermediateSize
|
||||
}
|
||||
|
||||
if config.VocabSize > 0 {
|
||||
info[fmt.Sprintf("%s.vocab_size", arch)] = config.VocabSize
|
||||
}
|
||||
|
||||
if paramCount > 0 {
|
||||
info["general.parameter_count"] = paramCount
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
return getTensorInfoFromManifest(manifest)
|
||||
}
|
||||
|
||||
// getTensorInfoFromManifest extracts tensor info from a manifest.
|
||||
// This is separated for testability.
|
||||
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
|
||||
var tensors []api.Tensor
|
||||
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read the safetensors header from the blob
|
||||
blobPath := manifest.BlobPath(layer.Digest)
|
||||
info, err := readSafetensorsHeader(blobPath)
|
||||
if err != nil {
|
||||
// Skip tensors we can't read
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert shape from int to uint64
|
||||
shape := make([]uint64, len(info.Shape))
|
||||
for i, s := range info.Shape {
|
||||
shape[i] = uint64(s)
|
||||
}
|
||||
|
||||
tensors = append(tensors, api.Tensor{
|
||||
Name: layer.Name,
|
||||
Type: info.Dtype,
|
||||
Shape: shape,
|
||||
})
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
}
|
||||
|
||||
// GetSafetensorsDtype returns the quantization type for a safetensors model.
|
||||
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
|
||||
// Otherwise returns the torch_dtype from config.json.
|
||||
func GetSafetensorsDtype(modelName string) (string, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Check if model is quantized by looking for _scale tensors
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
if strings.HasSuffix(layer.Name, "_scale") {
|
||||
// Model is quantized - return FP8 (affine quantization)
|
||||
return "FP8", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not quantized - return torch_dtype from config.json
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
|
||||
return "", fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
return cfg.TorchDtype, nil
|
||||
}
|
||||
|
||||
// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
|
||||
type safetensorsTensorInfo struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int64 `json:"shape"`
|
||||
}
|
||||
|
||||
// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
|
||||
// Safetensors format: 8-byte header size (little endian) + JSON header + tensor data
|
||||
func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return parseSafetensorsHeader(f)
|
||||
}
|
||||
|
||||
// parseSafetensorsHeader parses a safetensors header from a reader.
|
||||
// This is separated for testability.
|
||||
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
|
||||
// Read header size (8 bytes, little endian)
|
||||
var headerSize uint64
|
||||
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header size: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check - header shouldn't be too large
|
||||
if headerSize > 1024*1024 {
|
||||
return nil, fmt.Errorf("header size too large: %d", headerSize)
|
||||
}
|
||||
|
||||
// Read header JSON
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(r, headerBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
// Parse as map of tensor name -> info
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse header: %w", err)
|
||||
}
|
||||
|
||||
// Find the first (and should be only) tensor entry
|
||||
for name, raw := range header {
|
||||
if name == "__metadata__" {
|
||||
continue
|
||||
}
|
||||
var info safetensorsTensorInfo
|
||||
if err := json.Unmarshal(raw, &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no tensor found in header")
|
||||
}
|
||||
597
x/server/show_test.go
Normal file
@@ -0,0 +1,597 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
func TestBuildModelInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config modelConfig
|
||||
totalTensorBytes int64
|
||||
tensorCount int64
|
||||
wantArch string
|
||||
wantContextLen int
|
||||
wantEmbedLen int
|
||||
wantBlockCount int
|
||||
wantParamCount int64
|
||||
}{
|
||||
{
|
||||
name: "gemma3 model with model_type",
|
||||
config: modelConfig{
|
||||
ModelType: "gemma3",
|
||||
HiddenSize: 2560,
|
||||
NumHiddenLayers: 34,
|
||||
MaxPositionEmbeddings: 131072,
|
||||
IntermediateSize: 10240,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 4,
|
||||
VocabSize: 262144,
|
||||
TorchDtype: "bfloat16",
|
||||
},
|
||||
totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
|
||||
tensorCount: 1,
|
||||
wantArch: "gemma3",
|
||||
wantContextLen: 131072,
|
||||
wantEmbedLen: 2560,
|
||||
wantBlockCount: 34,
|
||||
wantParamCount: 4_300_000_000,
|
||||
},
|
||||
{
|
||||
name: "llama model with architectures array",
|
||||
config: modelConfig{
|
||||
Architectures: []string{"LlamaForCausalLM"},
|
||||
HiddenSize: 4096,
|
||||
NumHiddenLayers: 32,
|
||||
MaxPositionEmbeddings: 4096,
|
||||
IntermediateSize: 11008,
|
||||
NumAttentionHeads: 32,
|
||||
NumKeyValueHeads: 32,
|
||||
VocabSize: 32000,
|
||||
TorchDtype: "float16",
|
||||
},
|
||||
totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
|
||||
tensorCount: 1,
|
||||
wantArch: "llama",
|
||||
wantContextLen: 4096,
|
||||
wantEmbedLen: 4096,
|
||||
wantBlockCount: 32,
|
||||
wantParamCount: 7_000_000_000,
|
||||
},
|
||||
{
|
||||
name: "multimodal model with text_config",
|
||||
config: modelConfig{
|
||||
Architectures: []string{"Gemma3ForConditionalGeneration"},
|
||||
HiddenSize: 1152, // vision hidden size
|
||||
TextConfig: &struct {
|
||||
HiddenSize int `json:"hidden_size"`
|
||||
MaxPositionEmbeddings int `json:"max_position_embeddings"`
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
}{
|
||||
HiddenSize: 2560,
|
||||
MaxPositionEmbeddings: 131072,
|
||||
NumHiddenLayers: 34,
|
||||
},
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 4,
|
||||
VocabSize: 262144,
|
||||
TorchDtype: "bfloat16",
|
||||
},
|
||||
totalTensorBytes: 8_600_000_088,
|
||||
tensorCount: 1,
|
||||
wantArch: "gemma3",
|
||||
wantContextLen: 131072,
|
||||
wantEmbedLen: 2560,
|
||||
wantBlockCount: 34,
|
||||
wantParamCount: 4_300_000_000,
|
||||
},
|
||||
{
|
||||
name: "float32 model",
|
||||
config: modelConfig{
|
||||
ModelType: "test",
|
||||
HiddenSize: 512,
|
||||
NumHiddenLayers: 6,
|
||||
MaxPositionEmbeddings: 2048,
|
||||
TorchDtype: "float32",
|
||||
},
|
||||
totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
|
||||
tensorCount: 1,
|
||||
wantArch: "test",
|
||||
wantContextLen: 2048,
|
||||
wantEmbedLen: 512,
|
||||
wantBlockCount: 6,
|
||||
wantParamCount: 100_000_000,
|
||||
},
|
||||
{
|
||||
name: "multiple tensors with header overhead",
|
||||
config: modelConfig{
|
||||
ModelType: "test",
|
||||
HiddenSize: 256,
|
||||
NumHiddenLayers: 4,
|
||||
MaxPositionEmbeddings: 1024,
|
||||
TorchDtype: "bfloat16",
|
||||
},
|
||||
totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
|
||||
tensorCount: 10,
|
||||
wantArch: "test",
|
||||
wantContextLen: 1024,
|
||||
wantEmbedLen: 256,
|
||||
wantBlockCount: 4,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info := buildModelInfo(tt.config, tt.totalTensorBytes, tt.tensorCount)
|
||||
|
||||
// Check architecture
|
||||
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
|
||||
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
|
||||
}
|
||||
|
||||
// Check context length
|
||||
contextKey := tt.wantArch + ".context_length"
|
||||
if contextLen, ok := info[contextKey].(int); !ok || contextLen != tt.wantContextLen {
|
||||
t.Errorf("context_length = %v, want %v", info[contextKey], tt.wantContextLen)
|
||||
}
|
||||
|
||||
// Check embedding length
|
||||
embedKey := tt.wantArch + ".embedding_length"
|
||||
if embedLen, ok := info[embedKey].(int); !ok || embedLen != tt.wantEmbedLen {
|
||||
t.Errorf("embedding_length = %v, want %v", info[embedKey], tt.wantEmbedLen)
|
||||
}
|
||||
|
||||
// Check block count
|
||||
blockKey := tt.wantArch + ".block_count"
|
||||
if blockCount, ok := info[blockKey].(int); !ok || blockCount != tt.wantBlockCount {
|
||||
t.Errorf("block_count = %v, want %v", info[blockKey], tt.wantBlockCount)
|
||||
}
|
||||
|
||||
// Check parameter count
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
|
||||
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelInfo_ArchitectureConversion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
architectures []string
|
||||
modelType string
|
||||
wantArch string
|
||||
}{
|
||||
{
|
||||
name: "LlamaForCausalLM",
|
||||
architectures: []string{"LlamaForCausalLM"},
|
||||
wantArch: "llama",
|
||||
},
|
||||
{
|
||||
name: "Gemma3ForCausalLM",
|
||||
architectures: []string{"Gemma3ForCausalLM"},
|
||||
wantArch: "gemma3",
|
||||
},
|
||||
{
|
||||
name: "Gemma3ForConditionalGeneration",
|
||||
architectures: []string{"Gemma3ForConditionalGeneration"},
|
||||
wantArch: "gemma3",
|
||||
},
|
||||
{
|
||||
name: "Qwen2ForCausalLM",
|
||||
architectures: []string{"Qwen2ForCausalLM"},
|
||||
wantArch: "qwen2",
|
||||
},
|
||||
{
|
||||
name: "model_type takes precedence",
|
||||
architectures: []string{"LlamaForCausalLM"},
|
||||
modelType: "custom",
|
||||
wantArch: "custom",
|
||||
},
|
||||
{
|
||||
name: "empty architectures with model_type",
|
||||
architectures: nil,
|
||||
modelType: "mymodel",
|
||||
wantArch: "mymodel",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := modelConfig{
|
||||
Architectures: tt.architectures,
|
||||
ModelType: tt.modelType,
|
||||
}
|
||||
info := buildModelInfo(config, 0, 0)
|
||||
|
||||
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
|
||||
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelInfo_BytesPerParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dtype string
|
||||
totalBytes int64
|
||||
tensorCount int64
|
||||
wantParamCount int64
|
||||
}{
|
||||
{
|
||||
name: "bfloat16",
|
||||
dtype: "bfloat16",
|
||||
totalBytes: 2_000_088, // 1M * 2 + 88
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "float16",
|
||||
dtype: "float16",
|
||||
totalBytes: 2_000_088,
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "float32",
|
||||
dtype: "float32",
|
||||
totalBytes: 4_000_088, // 1M * 4 + 88
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "int8",
|
||||
dtype: "int8",
|
||||
totalBytes: 1_000_088, // 1M * 1 + 88
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "unknown dtype defaults to 2 bytes",
|
||||
dtype: "unknown",
|
||||
totalBytes: 2_000_088,
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "empty dtype defaults to 2 bytes",
|
||||
dtype: "",
|
||||
totalBytes: 2_000_088,
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := modelConfig{
|
||||
ModelType: "test",
|
||||
TorchDtype: tt.dtype,
|
||||
}
|
||||
info := buildModelInfo(config, tt.totalBytes, tt.tensorCount)
|
||||
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
|
||||
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header map[string]any
|
||||
wantDtype string
|
||||
wantShape []int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple tensor",
|
||||
header: map[string]any{
|
||||
"weight": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{2560, 262144},
|
||||
"data_offsets": []int64{0, 1342177280},
|
||||
},
|
||||
},
|
||||
wantDtype: "BF16",
|
||||
wantShape: []int64{2560, 262144},
|
||||
},
|
||||
{
|
||||
name: "with metadata",
|
||||
header: map[string]any{
|
||||
"__metadata__": map[string]any{
|
||||
"format": "pt",
|
||||
},
|
||||
"bias": map[string]any{
|
||||
"dtype": "F32",
|
||||
"shape": []int64{1024},
|
||||
"data_offsets": []int64{0, 4096},
|
||||
},
|
||||
},
|
||||
wantDtype: "F32",
|
||||
wantShape: []int64{1024},
|
||||
},
|
||||
{
|
||||
name: "float16 tensor",
|
||||
header: map[string]any{
|
||||
"layer.weight": map[string]any{
|
||||
"dtype": "F16",
|
||||
"shape": []int64{512, 512, 3, 3},
|
||||
"data_offsets": []int64{0, 4718592},
|
||||
},
|
||||
},
|
||||
wantDtype: "F16",
|
||||
wantShape: []int64{512, 512, 3, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create safetensors format: 8-byte size + JSON header
|
||||
headerJSON, err := json.Marshal(tt.header)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
||||
t.Fatalf("failed to write header size: %v", err)
|
||||
}
|
||||
buf.Write(headerJSON)
|
||||
|
||||
info, err := parseSafetensorsHeader(&buf)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if info.Dtype != tt.wantDtype {
|
||||
t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype)
|
||||
}
|
||||
|
||||
if len(info.Shape) != len(tt.wantShape) {
|
||||
t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape))
|
||||
} else {
|
||||
for i, s := range info.Shape {
|
||||
if s != tt.wantShape[i] {
|
||||
t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsHeader_Errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
wantErr: "failed to read header size",
|
||||
},
|
||||
{
|
||||
name: "truncated header size",
|
||||
data: []byte{0x01, 0x02, 0x03},
|
||||
wantErr: "failed to read header size",
|
||||
},
|
||||
{
|
||||
name: "header size too large",
|
||||
data: func() []byte {
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024)) // 2MB
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "header size too large",
|
||||
},
|
||||
{
|
||||
name: "truncated header",
|
||||
data: func() []byte {
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(100))
|
||||
buf.Write([]byte("short"))
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "failed to read header",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
data: func() []byte {
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(10))
|
||||
buf.Write([]byte("not json!!"))
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "failed to parse header",
|
||||
},
|
||||
{
|
||||
name: "no tensors in header",
|
||||
data: func() []byte {
|
||||
header := map[string]any{
|
||||
"__metadata__": map[string]any{"format": "pt"},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "no tensor found in header",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := parseSafetensorsHeader(bytes.NewReader(tt.data))
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !bytes.Contains([]byte(err.Error()), []byte(tt.wantErr)) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
// Create a temp directory for blobs
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create test tensor blobs
|
||||
tensors := []struct {
|
||||
name string
|
||||
digest string
|
||||
dtype string
|
||||
shape []int64
|
||||
}{
|
||||
{
|
||||
name: "model.embed_tokens.weight",
|
||||
digest: "sha256:abc123",
|
||||
dtype: "BF16",
|
||||
shape: []int64{262144, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.layers.0.self_attn.q_proj.weight",
|
||||
digest: "sha256:def456",
|
||||
dtype: "BF16",
|
||||
shape: []int64{2560, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.norm.weight",
|
||||
digest: "sha256:ghi789",
|
||||
dtype: "F32",
|
||||
shape: []int64{2560},
|
||||
},
|
||||
}
|
||||
|
||||
// Create blob files
|
||||
var layers []imagegen.ManifestLayer
|
||||
for _, tensor := range tensors {
|
||||
// Create safetensors blob
|
||||
header := map[string]any{
|
||||
tensor.name: map[string]any{
|
||||
"dtype": tensor.dtype,
|
||||
"shape": tensor.shape,
|
||||
"data_offsets": []int64{0, 1000},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
|
||||
// Write blob file
|
||||
blobName := "sha256-" + tensor.digest[7:]
|
||||
blobPath := filepath.Join(tempDir, blobName)
|
||||
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("failed to write blob: %v", err)
|
||||
}
|
||||
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Digest: tensor.digest,
|
||||
Size: int64(buf.Len() + 1000), // header + fake data
|
||||
Name: tensor.name,
|
||||
})
|
||||
}
|
||||
|
||||
// Add a non-tensor layer (should be skipped)
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
MediaType: "application/vnd.ollama.image.json",
|
||||
Digest: "sha256:config",
|
||||
Size: 100,
|
||||
Name: "config.json",
|
||||
})
|
||||
|
||||
manifest := &imagegen.ModelManifest{
|
||||
Manifest: &imagegen.Manifest{
|
||||
Layers: layers,
|
||||
},
|
||||
BlobDir: tempDir,
|
||||
}
|
||||
|
||||
result, err := getTensorInfoFromManifest(manifest)
|
||||
if err != nil {
|
||||
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Errorf("got %d tensors, want 3", len(result))
|
||||
}
|
||||
|
||||
// Verify each tensor
|
||||
for i, tensor := range tensors {
|
||||
if i >= len(result) {
|
||||
break
|
||||
}
|
||||
if result[i].Name != tensor.name {
|
||||
t.Errorf("tensor[%d].Name = %v, want %v", i, result[i].Name, tensor.name)
|
||||
}
|
||||
if result[i].Type != tensor.dtype {
|
||||
t.Errorf("tensor[%d].Type = %v, want %v", i, result[i].Type, tensor.dtype)
|
||||
}
|
||||
if len(result[i].Shape) != len(tensor.shape) {
|
||||
t.Errorf("tensor[%d].Shape length = %v, want %v", i, len(result[i].Shape), len(tensor.shape))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSafetensorsHeader(t *testing.T) {
|
||||
// Create a temp file with a valid safetensors header
|
||||
tempDir := t.TempDir()
|
||||
|
||||
header := map[string]any{
|
||||
"test_tensor": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{1024, 768},
|
||||
"data_offsets": []int64{0, 1572864},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
|
||||
filePath := filepath.Join(tempDir, "test.safetensors")
|
||||
if err := os.WriteFile(filePath, buf.Bytes(), 0o644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
info, err := readSafetensorsHeader(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("readSafetensorsHeader() error = %v", err)
|
||||
}
|
||||
|
||||
if info.Dtype != "BF16" {
|
||||
t.Errorf("Dtype = %v, want BF16", info.Dtype)
|
||||
}
|
||||
if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 {
|
||||
t.Errorf("Shape = %v, want [1024, 768]", info.Shape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSafetensorsHeader_FileNotFound(t *testing.T) {
|
||||
_, err := readSafetensorsHeader("/nonexistent/path/file.safetensors")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
@@ -54,6 +54,16 @@ func (r *Registry) RegisterBash() {
|
||||
r.Register(&BashTool{})
|
||||
}
|
||||
|
||||
// RegisterWebSearch adds the web search tool to the registry.
|
||||
func (r *Registry) RegisterWebSearch() {
|
||||
r.Register(&WebSearchTool{})
|
||||
}
|
||||
|
||||
// RegisterWebFetch adds the web fetch tool to the registry.
|
||||
func (r *Registry) RegisterWebFetch() {
|
||||
r.Register(&WebFetchTool{})
|
||||
}
|
||||
|
||||
// Get retrieves a tool by name.
|
||||
func (r *Registry) Get(name string) (Tool, bool) {
|
||||
tool, ok := r.tools[name]
|
||||
|
||||
162
x/tools/webfetch.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
webFetchAPI = "https://ollama.com/api/web_fetch"
|
||||
webFetchTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// ErrWebFetchAuthRequired is returned when web fetch requires authentication
|
||||
var ErrWebFetchAuthRequired = errors.New("web fetch requires authentication")
|
||||
|
||||
// WebFetchTool implements web page fetching using Ollama's hosted API.
|
||||
type WebFetchTool struct{}
|
||||
|
||||
// Name returns the tool name.
|
||||
func (w *WebFetchTool) Name() string {
|
||||
return "web_fetch"
|
||||
}
|
||||
|
||||
// Description returns a description of the tool.
|
||||
func (w *WebFetchTool) Description() string {
|
||||
return "Fetch and extract text content from a web page. Use this to read the full content of a URL found in search results or provided by the user."
|
||||
}
|
||||
|
||||
// Schema returns the tool's parameter schema.
|
||||
func (w *WebFetchTool) Schema() api.ToolFunction {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("url", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The URL to fetch and extract content from",
|
||||
})
|
||||
return api.ToolFunction{
|
||||
Name: w.Name(),
|
||||
Description: w.Description(),
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: []string{"url"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// webFetchRequest is the request body for the web fetch API.
|
||||
type webFetchRequest struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// webFetchResponse is the response from the web fetch API.
|
||||
type webFetchResponse struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Links []string `json:"links,omitempty"`
|
||||
}
|
||||
|
||||
// Execute fetches content from a web page.
|
||||
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
|
||||
func (w *WebFetchTool) Execute(args map[string]any) (string, error) {
|
||||
urlStr, ok := args["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return "", fmt.Errorf("url parameter is required")
|
||||
}
|
||||
|
||||
// Validate URL
|
||||
if _, err := url.Parse(urlStr); err != nil {
|
||||
return "", fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// Prepare request
|
||||
reqBody := webFetchRequest{
|
||||
URL: urlStr,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
// Parse URL and add timestamp for signing
|
||||
fetchURL, err := url.Parse(webFetchAPI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parsing fetch URL: %w", err)
|
||||
}
|
||||
|
||||
q := fetchURL.Query()
|
||||
q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
fetchURL.RawQuery = q.Encode()
|
||||
|
||||
// Sign the request using Ollama key (~/.ollama/id_ed25519)
|
||||
ctx := context.Background()
|
||||
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, fetchURL.RequestURI())
|
||||
signature, err := auth.Sign(ctx, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("signing request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fetchURL.String(), bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if signature != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||
}
|
||||
|
||||
// Send request
|
||||
client := &http.Client{Timeout: webFetchTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sending request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
return "", ErrWebFetchAuthRequired
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("web fetch API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var fetchResp webFetchResponse
|
||||
if err := json.Unmarshal(body, &fetchResp); err != nil {
|
||||
return "", fmt.Errorf("parsing response: %w", err)
|
||||
}
|
||||
|
||||
// Format result
|
||||
var sb strings.Builder
|
||||
if fetchResp.Title != "" {
|
||||
sb.WriteString(fmt.Sprintf("Title: %s\n\n", fetchResp.Title))
|
||||
}
|
||||
|
||||
if fetchResp.Content != "" {
|
||||
sb.WriteString("Content:\n")
|
||||
sb.WriteString(fetchResp.Content)
|
||||
} else {
|
||||
sb.WriteString("No content could be extracted from the page.")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||