Compare commits

..

1 Commits

Author SHA1 Message Date
Bruce MacDonald
f92a82db15 app: match model picker to server models
Rather than adding models to the model picker to guide users on first use, take the ollama tags response as a source of truth.
2026-02-05 15:43:22 -08:00
9 changed files with 42 additions and 633 deletions

View File

@@ -1,13 +1,12 @@
import { useQuery } from "@tanstack/react-query";
import { Model } from "@/gotypes";
import { getModels } from "@/api";
import { mergeModels } from "@/utils/mergeModels";
import { useSettings } from "./useSettings";
import { useMemo } from "react";
const DEFAULT_MODEL = "gemma3:4b";
export function useModels(searchQuery = "") {
const { settings } = useSettings();
const localQuery = useQuery<Model[], Error>({
const query = useQuery<Model[], Error>({
queryKey: ["models", searchQuery],
queryFn: () => getModels(searchQuery),
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
@@ -19,33 +18,18 @@ export function useModels(searchQuery = "") {
refetchIntervalInBackground: true,
});
const allModels = useMemo(() => {
const models = mergeModels(localQuery.data || [], settings.airplaneMode);
if (searchQuery && searchQuery.trim()) {
const query = searchQuery.toLowerCase().trim();
const filteredModels = models.filter((model) =>
model.model.toLowerCase().includes(query),
);
const seen = new Set<string>();
return filteredModels.filter((model) => {
const currentModel = model.model.toLowerCase();
if (seen.has(currentModel)) {
return false;
}
seen.add(currentModel);
return true;
});
const models = useMemo(() => {
const data = query.data || [];
if (data.length === 0) {
return [new Model({ model: DEFAULT_MODEL })];
}
return models;
}, [localQuery.data, searchQuery, settings.airplaneMode]);
return data;
}, [query.data]);
return {
...localQuery,
data: allModels,
isLoading: localQuery.isLoading,
...query,
data: models,
isLoading: query.isLoading,
};
}

View File

@@ -4,7 +4,6 @@ import { useModels } from "./useModels";
import { useChat } from "./useChats";
import { useSettings } from "./useSettings.ts";
import { Model } from "@/gotypes";
import { FEATURED_MODELS } from "@/utils/mergeModels";
import { getTotalVRAM } from "@/utils/vram.ts";
import { getInferenceCompute } from "@/api";
@@ -46,77 +45,13 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
const restoredChatRef = useRef<string | null>(null);
const selectedModel: Model | null = useMemo(() => {
// if airplane mode is on and selected model ends with cloud,
// switch to recommended default model
if (settings.airplaneMode && settings.selectedModel?.endsWith("cloud")) {
return (
models.find((m) => m.model === recommendedModel) ||
models.find((m) => m.isCloud) ||
models.find((m) => m.digest === undefined || m.digest === "") ||
models[0] ||
null
);
}
// Migration logic: if turboEnabled is true and selectedModel is a base model,
// migrate to the cloud version and disable turboEnabled permanently
// TODO: remove this logic in a future release
const baseModelsToMigrate = [
"gpt-oss:20b",
"gpt-oss:120b",
"deepseek-v3.1:671b",
"qwen3-coder:480b",
];
const shouldMigrate =
!settings.airplaneMode &&
settings.turboEnabled &&
baseModelsToMigrate.includes(settings.selectedModel);
if (shouldMigrate) {
const cloudModel = `${settings.selectedModel}cloud`;
return (
models.find((m) => m.model === cloudModel) ||
new Model({
model: cloudModel,
cloud: true,
ollama_host: false,
})
);
}
return (
models.find((m) => m.model === settings.selectedModel) ||
(settings.selectedModel &&
new Model({
model: settings.selectedModel,
cloud: FEATURED_MODELS.some(
(f) => f.endsWith("cloud") && f === settings.selectedModel,
),
ollama_host: false,
})) ||
new Model({ model: settings.selectedModel })) ||
null
);
}, [models, settings.selectedModel, settings.airplaneMode, recommendedModel]);
useEffect(() => {
if (!selectedModel) return;
if (
settings.airplaneMode &&
settings.selectedModel?.endsWith("cloud") &&
selectedModel.model !== settings.selectedModel
) {
setSettings({ SelectedModel: selectedModel.model });
}
if (
!settings.airplaneMode &&
settings.turboEnabled &&
selectedModel.model !== settings.selectedModel
) {
setSettings({ SelectedModel: selectedModel.model, TurboEnabled: false });
}
}, [selectedModel, settings.airplaneMode, settings.selectedModel]);
}, [models, settings.selectedModel]);
// Set model from chat history when chat data loads
useEffect(() => {
@@ -169,8 +104,6 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
const defaultModel =
models.find((m) => m.model === recommendedModel) ||
models.find((m) => m.isCloud()) ||
models.find((m) => m.digest === undefined || m.digest === "") ||
models[0];
if (defaultModel) {

View File

@@ -1,128 +0,0 @@
import { describe, it, expect } from "vitest";
import { Model } from "@/gotypes";
import { mergeModels, FEATURED_MODELS } from "@/utils/mergeModels";
import "@/api";
describe("Model merging logic", () => {
it("should handle cloud models with -cloud suffix", () => {
const localModels: Model[] = [
new Model({ model: "gpt-oss:120b-cloud" }),
new Model({ model: "llama3:latest" }),
new Model({ model: "mistral:latest" }),
];
const merged = mergeModels(localModels);
// First verify cloud models are first and in FEATURED_MODELS order
const cloudModels = FEATURED_MODELS.filter((m: string) =>
m.endsWith("cloud"),
);
for (let i = 0; i < cloudModels.length; i++) {
expect(merged[i].model).toBe(cloudModels[i]);
expect(merged[i].isCloud()).toBe(true);
}
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
const nonCloudFeatured = FEATURED_MODELS.filter(
(m: string) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i + cloudModels.length];
expect(model.model).toBe(nonCloudFeatured[i]);
expect(model.isCloud()).toBe(false);
}
// Verify local models are preserved and come after featured models
const featuredCount = FEATURED_MODELS.length;
expect(merged[featuredCount].model).toBe("llama3:latest");
expect(merged[featuredCount + 1].model).toBe("mistral:latest");
// Length should be exactly featured models plus our local models
expect(merged.length).toBe(FEATURED_MODELS.length + 2);
});
it("should hide cloud models in airplane mode", () => {
const localModels: Model[] = [
new Model({ model: "gpt-oss:120b-cloud" }),
new Model({ model: "llama3:latest" }),
new Model({ model: "mistral:latest" }),
];
const merged = mergeModels(localModels, true); // airplane mode = true
// No cloud models should be present
const cloudModels = merged.filter((m) => m.isCloud());
expect(cloudModels.length).toBe(0);
// Should have non-cloud featured models
const nonCloudFeatured = FEATURED_MODELS.filter(
(m) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i];
expect(model.model).toBe(nonCloudFeatured[i]);
expect(model.isCloud()).toBe(false);
}
// Local models should be preserved
const featuredCount = nonCloudFeatured.length;
expect(merged[featuredCount].model).toBe("llama3:latest");
expect(merged[featuredCount + 1].model).toBe("mistral:latest");
});
it("should handle empty input", () => {
const merged = mergeModels([]);
// First verify cloud models are first and in FEATURED_MODELS order
const cloudModels = FEATURED_MODELS.filter((m) => m.endsWith("cloud"));
for (let i = 0; i < cloudModels.length; i++) {
expect(merged[i].model).toBe(cloudModels[i]);
expect(merged[i].isCloud()).toBe(true);
}
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
const nonCloudFeatured = FEATURED_MODELS.filter(
(m) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i + cloudModels.length];
expect(model.model).toBe(nonCloudFeatured[i]);
expect(model.isCloud()).toBe(false);
}
// Length should be exactly FEATURED_MODELS length
expect(merged.length).toBe(FEATURED_MODELS.length);
});
it("should sort models correctly", () => {
const localModels: Model[] = [
new Model({ model: "zephyr:latest" }),
new Model({ model: "alpha:latest" }),
new Model({ model: "gpt-oss:120b-cloud" }),
];
const merged = mergeModels(localModels);
// First verify cloud models are first and in FEATURED_MODELS order
const cloudModels = FEATURED_MODELS.filter((m) => m.endsWith("cloud"));
for (let i = 0; i < cloudModels.length; i++) {
expect(merged[i].model).toBe(cloudModels[i]);
expect(merged[i].isCloud()).toBe(true);
}
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
const nonCloudFeatured = FEATURED_MODELS.filter(
(m) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i + cloudModels.length];
expect(model.model).toBe(nonCloudFeatured[i]);
expect(model.isCloud()).toBe(false);
}
// Non-featured local models should be at the end in alphabetical order
const featuredCount = FEATURED_MODELS.length;
expect(merged[featuredCount].model).toBe("alpha:latest");
expect(merged[featuredCount + 1].model).toBe("zephyr:latest");
});
});

View File

@@ -1,101 +0,0 @@
import { Model } from "@/gotypes";
// Featured models list (in priority order)
export const FEATURED_MODELS = [
"gpt-oss:120b-cloud",
"gpt-oss:20b-cloud",
"deepseek-v3.1:671b-cloud",
"qwen3-coder:480b-cloud",
"qwen3-vl:235b-cloud",
"minimax-m2:cloud",
"glm-4.6:cloud",
"gpt-oss:120b",
"gpt-oss:20b",
"gemma3:27b",
"gemma3:12b",
"gemma3:4b",
"gemma3:1b",
"deepseek-r1:8b",
"qwen3-coder:30b",
"qwen3-vl:30b",
"qwen3-vl:8b",
"qwen3-vl:4b",
"qwen3:30b",
"qwen3:8b",
"qwen3:4b",
];
function alphabeticalSort(a: Model, b: Model): number {
return a.model.toLowerCase().localeCompare(b.model.toLowerCase());
}
//Merges models, sorting cloud models first, then other models
export function mergeModels(
localModels: Model[],
airplaneMode: boolean = false,
): Model[] {
const allModels = (localModels || []).map((model) => model);
// 1. Get cloud models from local models and featured list
const cloudModels = [...allModels.filter((m) => m.isCloud())];
// Add any cloud models from FEATURED_MODELS that aren't in local models
FEATURED_MODELS.filter((f) => f.endsWith("cloud")).forEach((cloudModel) => {
if (!cloudModels.some((m) => m.model === cloudModel)) {
cloudModels.push(new Model({ model: cloudModel }));
}
});
// 2. Get other featured models (non-cloud)
const featuredModels = FEATURED_MODELS.filter(
(f) => !f.endsWith("cloud"),
).map((model) => {
// Check if this model exists in local models
const localMatch = allModels.find(
(m) => m.model.toLowerCase() === model.toLowerCase(),
);
if (localMatch) return localMatch;
return new Model({
model,
});
});
// 3. Get remaining local models that aren't featured and aren't cloud models
const remainingModels = allModels.filter(
(model) =>
!model.isCloud() &&
!FEATURED_MODELS.some(
(f) => f.toLowerCase() === model.model.toLowerCase(),
),
);
cloudModels.sort((a, b) => {
const aIndex = FEATURED_MODELS.indexOf(a.model);
const bIndex = FEATURED_MODELS.indexOf(b.model);
// If both are featured, sort by their position in FEATURED_MODELS
if (aIndex !== -1 && bIndex !== -1) {
return aIndex - bIndex;
}
// If only one is featured, featured model comes first
if (aIndex !== -1 && bIndex === -1) return -1;
if (aIndex === -1 && bIndex !== -1) return 1;
// If neither is featured, sort alphabetically
return a.model.toLowerCase().localeCompare(b.model.toLowerCase());
});
featuredModels.sort(
(a, b) =>
FEATURED_MODELS.indexOf(a.model) - FEATURED_MODELS.indexOf(b.model),
);
remainingModels.sort(alphabeticalSort);
return airplaneMode
? [...featuredModels, ...remainingModels]
: [...cloudModels, ...featuredModels, ...remainingModels];
}

View File

@@ -188,8 +188,6 @@ func LogLevel() slog.Level {
var (
// FlashAttention enables the experimental flash attention feature.
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
// DebugLogRequests logs inference requests to disk for replay/debugging.
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
// KvCacheType is the quantization type for the K/V cache.
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
// NoHistory disables readline history.
@@ -275,27 +273,26 @@ type EnvVar struct {
func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

View File

@@ -1,144 +0,0 @@
package server
import (
"bytes"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/envconfig"
)
type inferenceRequestLogger struct {
dir string
counter uint64
}
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
if err != nil {
return nil, err
}
return &inferenceRequestLogger{dir: dir}, nil
}
func (s *Server) initRequestLogging() error {
if !envconfig.DebugLogRequests() {
return nil
}
requestLogger, err := newInferenceRequestLogger()
if err != nil {
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
}
s.requestLogger = requestLogger
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
return nil
}
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
if s.requestLogger == nil {
return handlers
}
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
}
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request == nil {
c.Next()
return
}
method := c.Request.Method
host := c.Request.Host
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
contentType := c.GetHeader("Content-Type")
var body []byte
if c.Request.Body != nil {
var err error
body, err = io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if err != nil {
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
}
}
c.Next()
l.log(route, method, scheme, host, contentType, body)
}
}
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
if l == nil || l.dir == "" {
return
}
if contentType == "" {
contentType = "application/json"
}
if host == "" || scheme == "" {
base := envconfig.Host()
if host == "" {
host = base.Host
}
if scheme == "" {
scheme = base.Scheme
}
}
routeForFilename := sanitizeRouteForFilename(route)
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
bodyPath := filepath.Join(l.dir, bodyFilename)
curlPath := filepath.Join(l.dir, curlFilename)
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
slog.Warn("failed to write debug request body", "route", route, "error", err)
return
}
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
return
}
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
}
func sanitizeRouteForFilename(route string) string {
route = strings.TrimPrefix(route, "/")
if route == "" {
return "root"
}
var b strings.Builder
b.Grow(len(route))
for _, r := range route {
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
b.WriteRune(r)
} else {
b.WriteByte('_')
}
}
return b.String()
}

View File

@@ -81,7 +81,6 @@ type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
requestLogger *inferenceRequestLogger
}
func init() {
@@ -1584,24 +1583,24 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)...)
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)...)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)...)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
@@ -1651,9 +1650,6 @@ func Serve(ln net.Listener) error {
}
s := &Server{addr: ln.Addr()}
if err := s.initRequestLogging(); err != nil {
return err
}
var rc *ollama.Registry
if useClient2 {

View File

@@ -1,128 +0,0 @@
package server
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
gin.SetMode(gin.TestMode)
logDir := t.TempDir()
requestLogger := &inferenceRequestLogger{dir: logDir}
const route = "/v1/chat/completions"
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
var bodySeenByHandler string
r := gin.New()
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
t.Fatalf("failed to read body in handler: %v", err)
}
bodySeenByHandler = string(body)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
req.Host = "127.0.0.1:11434"
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if bodySeenByHandler != requestBody {
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
}
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
if err != nil {
t.Fatalf("failed to glob body logs: %v", err)
}
if len(bodyFiles) != 1 {
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
}
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
if err != nil {
t.Fatalf("failed to glob curl logs: %v", err)
}
if len(curlFiles) != 1 {
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
}
bodyData, err := os.ReadFile(bodyFiles[0])
if err != nil {
t.Fatalf("failed to read body log: %v", err)
}
if string(bodyData) != requestBody {
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
}
curlData, err := os.ReadFile(curlFiles[0])
if err != nil {
t.Fatalf("failed to read curl log: %v", err)
}
curlString := string(curlData)
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
}
bodyFileName := filepath.Base(bodyFiles[0])
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
}
}
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
requestLogger, err := newInferenceRequestLogger()
if err != nil {
t.Fatalf("expected no error creating request logger: %v", err)
}
t.Cleanup(func() {
_ = os.RemoveAll(requestLogger.dir)
})
if requestLogger == nil || requestLogger.dir == "" {
t.Fatalf("expected request logger directory to be set")
}
info, err := os.Stat(requestLogger.dir)
if err != nil {
t.Fatalf("expected directory to exist: %v", err)
}
if !info.IsDir() {
t.Fatalf("expected %q to be a directory", requestLogger.dir)
}
}
func TestSanitizeRouteForFilename(t *testing.T) {
tests := []struct {
route string
want string
}{
{route: "/api/generate", want: "api_generate"},
{route: "/v1/chat/completions", want: "v1_chat_completions"},
{route: "/v1/messages", want: "v1_messages"},
}
for _, tt := range tests {
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
}
}
}

View File

@@ -417,9 +417,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
numParallel = 1
}
// Some architectures are not safe with num_parallel > 1.
// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and uses an encoder cache which cannot be used with num_parallel > 1
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}