Compare commits

..

1 Commits

Author SHA1 Message Date
Bruce MacDonald
d5a2849d1f cmd: set codex env vars on launch and handle zstd request bodies
The Codex runner was not setting OPENAI_BASE_URL or OPENAI_API_KEY, this prevents Codex from sending requests to api.openai.com instead of the local Ollama server. This mirrors the approach used by the Claude runner.

Codex v0.98.0 sends zstd-compressed request bodies to the /v1/responses endpoint. Add decompression support in ResponsesMiddleware with an 8MB max decompressed size limit to prevent resource exhaustion.
2026-02-06 10:42:18 -08:00
11 changed files with 331 additions and 402 deletions

View File

@@ -147,7 +147,7 @@ ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/imagegen/mlx x/imagegen/mlx
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local

View File

@@ -897,5 +897,11 @@ func countContentBlock(block any) int {
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}

View File

@@ -1826,18 +1826,6 @@ func NewCLI() *cobra.Command {
return
}
// If no args, run launch to show interactive app selector
if len(args) == 0 {
if err := checkServerHeartbeat(cmd, args); err != nil {
cobra.CheckErr(err)
return
}
if err := config.RunLaunch(cmd, args, "", false); err != nil {
cobra.CheckErr(err)
}
return
}
cmd.Print(cmd.UsageString())
},
}

View File

@@ -6,6 +6,7 @@ import (
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}

View File

@@ -59,12 +59,6 @@ var integrations = map[string]Runner{
"openclaw": &Openclaw{},
}
// IsIntegration returns true if the given name is a valid integration.
func IsIntegration(name string) bool {
_, ok := integrations[strings.ToLower(name)]
return ok
}
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
@@ -82,7 +76,7 @@ var integrationAliases = map[string]bool{
func selectIntegration() (string, error) {
if len(integrations) == 0 {
return "", fmt.Errorf("no apps available")
return "", fmt.Errorf("no integrations available")
}
names := slices.Sorted(maps.Keys(integrations))
@@ -99,14 +93,14 @@ func selectIntegration() (string, error) {
items = append(items, selectItem{Name: name, Description: description})
}
return selectPrompt("Select app:", items)
return selectPrompt("Select integration:", items)
}
// selectModels lets the user select models for an integration
func selectModels(ctx context.Context, name, current string) ([]string, error) {
r, ok := integrations[name]
if !ok {
return nil, fmt.Errorf("unknown app: %s", name)
return nil, fmt.Errorf("unknown integration: %s", name)
}
client, err := api.ClientFromEnvironment()
@@ -200,20 +194,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
return nil
}
// showOrPull checks if a model exists via client.Show and offers to pull it if not found.
func showOrPull(ctx context.Context, client *api.Client, model string) error {
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
return nil
}
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
return err
} else if !ok {
return errCancelled
}
fmt.Fprintf(os.Stderr, "\n")
return pullModel(ctx, client, model)
}
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -312,7 +292,7 @@ func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]
func runIntegration(name, modelName string, args []string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown app: %s", name)
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
@@ -341,230 +321,17 @@ func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, na
return saveAliases(name, aliases)
}
// RunLaunch executes the launch logic for the given integration and arguments.
// This can be called directly from the root command (with empty modelFlag/configFlag)
// or via the launch subcommand.
func RunLaunch(cmd *cobra.Command, args []string, modelFlag string, configFlag bool) error {
// Extract integration name and args to pass through using -- separator
var name string
var passArgs []string
dashIdx := cmd.ArgsLenAtDash()
if dashIdx == -1 {
// No "--" separator: only allow 0 or 1 args (integration name)
if len(args) > 1 {
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the app", args[1:])
}
if len(args) == 1 {
name = args[0]
}
} else {
// "--" was used: args before it = integration name, args after = passthrough
if dashIdx > 1 {
return fmt.Errorf("expected at most 1 app name before '--', got %d", dashIdx)
}
if dashIdx == 1 {
name = args[0]
}
passArgs = args[dashIdx:]
}
if name == "" {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
r, ok := integrations[strings.ToLower(name)]
if !ok {
return fmt.Errorf("unknown app: %s", name)
}
// Handle AliasConfigurer integrations (claude, codex)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Validate --model flag if provided
if modelFlag != "" {
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
var model string
var existingAliases map[string]string
// Load saved config
if cfg, err := loadIntegration(name); err == nil {
existingAliases = cfg.Aliases
if len(cfg.Models) > 0 {
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
}
}
}
// --model flag overrides saved model
if modelFlag != "" {
model = modelFlag
}
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
if err := showOrPull(cmd.Context(), client, model); err != nil {
model = ""
}
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
}
// Ensure cloud models are authenticated
if isCloudModel(cmd.Context(), client, model) {
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
return err
}
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
// Launch (unless --config without confirmation)
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, model, passArgs)
}
return nil
}
return runIntegration(name, model, passArgs)
}
// Validate --model flag for non-AliasConfigurer integrations
if modelFlag != "" {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
if errors.Is(err, errCancelled) {
return nil
}
return err
}
}
var models []string
if modelFlag != "" {
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
if m != modelFlag {
models = append(models, m)
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if editor, isEditor := r.(Editor); isEditor {
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if _, isEditor := r.(Editor); isEditor {
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
} else {
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
}
}
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0], passArgs)
}
return nil
}
if runner, isRunner := r.(Runner); isRunner {
return runner.Run(models[0], passArgs)
}
return nil
}
// LaunchCmd returns the cobra command for launching integrations.
func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error) *cobra.Command {
var modelFlag string
var configFlag bool
cmd := &cobra.Command{
Use: "launch [APP] [-- [EXTRA_ARGS...]]",
Short: "Launch an app with Ollama",
Long: `Launch an app configured with Ollama models.
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
Supported apps:
Supported integrations:
claude Claude Code
codex Codex
droid Droid
@@ -576,12 +343,200 @@ Examples:
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to app)
ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`,
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
return RunLaunch(cmd, args, modelFlag, configFlag)
// Extract integration name and args to pass through using -- separator
var name string
var passArgs []string
dashIdx := cmd.ArgsLenAtDash()
if dashIdx == -1 {
// No "--" separator: only allow 0 or 1 args (integration name)
if len(args) > 1 {
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
}
if len(args) == 1 {
name = args[0]
}
} else {
// "--" was used: args before it = integration name, args after = passthrough
if dashIdx > 1 {
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
}
if dashIdx == 1 {
name = args[0]
}
passArgs = args[dashIdx:]
}
if name == "" {
var err error
name, err = selectIntegration()
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
r, ok := integrations[strings.ToLower(name)]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
// Handle AliasConfigurer integrations (claude, codex)
if ac, ok := r.(AliasConfigurer); ok {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Validate --model flag if provided
if modelFlag != "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
var model string
var existingAliases map[string]string
// Load saved config
if cfg, err := loadIntegration(name); err == nil {
existingAliases = cfg.Aliases
if len(cfg.Models) > 0 {
model = cfg.Models[0]
// AliasConfigurer integrations use single model; sanitize if multiple
if len(cfg.Models) > 1 {
_ = saveIntegration(name, []string{model})
}
}
}
// --model flag overrides saved model
if modelFlag != "" {
model = modelFlag
}
// Validate saved model still exists
if model != "" && modelFlag == "" {
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
model = ""
}
}
// If no valid model or --config flag, show picker
if model == "" || configFlag {
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag)
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
model = aliases["primary"]
existingAliases = aliases
}
// Sync aliases and save
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
}
if err := saveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
// Launch (unless --config without confirmation)
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
return runIntegration(name, model, passArgs)
}
return nil
}
return runIntegration(name, model, passArgs)
}
// Validate --model flag for non-AliasConfigurer integrations
if modelFlag != "" {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
return fmt.Errorf("model %q not found", modelFlag)
}
}
var models []string
if modelFlag != "" {
models = []string{modelFlag}
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
for _, m := range existing.Models {
if m != modelFlag {
models = append(models, m)
}
}
}
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
return runIntegration(name, saved.Models[0], passArgs)
} else {
var err error
models, err = selectModels(cmd.Context(), name, "")
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
}
if editor, isEditor := r.(Editor); isEditor {
paths := editor.Paths()
if len(paths) > 0 {
fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
for _, p := range paths {
fmt.Fprintf(os.Stderr, " %s\n", p)
}
fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
if ok, _ := confirmPrompt("Proceed?"); !ok {
return nil
}
}
}
if err := saveIntegration(name, models); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
if editor, isEditor := r.(Editor); isEditor {
if err := editor.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
}
if _, isEditor := r.(Editor); isEditor {
if len(models) == 1 {
fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
} else {
fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
}
}
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0], passArgs)
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0], passArgs)
},
}
@@ -695,7 +650,7 @@ func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
if client == nil {
return false
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
if err != nil {
return false
}

View File

@@ -2,17 +2,12 @@ package config
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/spf13/cobra"
)
@@ -98,8 +93,8 @@ func TestLaunchCmd(t *testing.T) {
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [APP] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [APP] [-- [EXTRA_ARGS...]]")
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
@@ -133,8 +128,8 @@ func TestRunIntegration_UnknownIntegration(t *testing.T) {
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
if !strings.Contains(err.Error(), "unknown app") {
t.Errorf("error should mention 'unknown app', got: %v", err)
if !strings.Contains(err.Error(), "unknown integration") {
t.Errorf("error should mention 'unknown integration', got: %v", err)
}
}
@@ -544,136 +539,3 @@ func TestAliasConfigurerInterface(t *testing.T) {
}
})
}
func TestShowOrPull_ModelExists(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"test-model"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := showOrPull(context.Background(), client, "test-model")
if err != nil {
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
}
}
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"model not found"}`)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// confirmPrompt will fail in test (no terminal), so showOrPull should return an error
err := showOrPull(context.Background(), client, "missing-model")
if err == nil {
t.Error("showOrPull should return error when model not found and no terminal available")
}
}
func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
var receivedModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
receivedModel = req.Model
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
_ = showOrPull(context.Background(), client, "qwen3:8b")
if receivedModel != "qwen3:8b" {
t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
}
}
func TestEnsureAuth_NoCloudModels(t *testing.T) {
// ensureAuth should be a no-op when no cloud models are selected
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("no API calls expected when no cloud models selected")
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
if err != nil {
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
}
}
func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
// ensureAuth should only care about models in cloudModels map
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"name":"testuser"}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"cloud-model:cloud", "local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
}
if !whoamiCalled {
t.Error("expected whoami to be called for cloud model")
}
}
func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
var whoamiCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/me" {
whoamiCalled = true
}
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
// cloudModels has entries but none are in selected
cloudModels := map[string]bool{"cloud-model:cloud": true}
selected := []string{"local-model"}
err := ensureAuth(context.Background(), client, cloudModels, selected)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if whoamiCalled {
t.Error("whoami should not be called when no cloud models are selected")
}
}

View File

@@ -312,7 +312,7 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

1
go.mod
View File

@@ -24,6 +24,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/klauspost/compress v1.18.3
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

4
go.sum
View File

@@ -106,7 +106,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
@@ -134,8 +133,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "zstd" {
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
return
}
defer reader.Close()
c.Request.Body = io.NopCloser(reader)
c.Request.Header.Del("Content-Encoding")
}
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
})
}
}
func zstdCompress(t *testing.T, data []byte) []byte {
t.Helper()
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(data); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return buf.Bytes()
}
func TestResponsesMiddlewareZstd(t *testing.T) {
tests := []struct {
name string
body string
useZstd bool
oversized bool
wantCode int
wantModel string
wantMessage string
}{
{
name: "plain JSON",
body: `{"model": "test-model", "input": "Hello"}`,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd compressed",
body: `{"model": "test-model", "input": "Hello"}`,
useZstd: true,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd over max decompressed size",
oversized: true,
useZstd: true,
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedRequest *api.ChatRequest
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
c.Status(http.StatusOK)
})
var bodyReader io.Reader
if tt.oversized {
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
} else if tt.useZstd {
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
} else {
bodyReader = strings.NewReader(tt.body)
}
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
req.Header.Set("Content-Type", "application/json")
if tt.useZstd || tt.oversized {
req.Header.Set("Content-Encoding", "zstd")
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.wantCode {
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
}
if tt.wantCode != http.StatusOK {
return
}
if capturedRequest == nil {
t.Fatal("expected captured request, got nil")
}
if capturedRequest.Model != tt.wantModel {
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
}
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
}
})
}
}