Compare commits


2 Commits

Author SHA1 Message Date
Parth Sareen
6b2abfb433 server: add tests and fix isHuggingFaceURL edge case
- Add comprehensive tests for isHuggingFaceURL and getNumDownloadParts
- Fix bug where domains ending in huggingface.co (like nothuggingface.co)
  would incorrectly match as HuggingFace URLs
- Improve code comments with more detailed documentation
2026-01-18 16:45:17 -08:00
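The isHuggingFaceURL edge case described in the commit above comes down to suffix matching on the URL host. A minimal sketch of the corrected check, assuming the helper takes a raw URL string (the real function lives in the server package; the exact signature and parsing details here are assumptions, not the actual implementation):

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// isHuggingFaceURL reports whether rawURL points at huggingface.co or one of
// its subdomains. A bare strings.HasSuffix(host, "huggingface.co") would also
// accept hosts like "nothuggingface.co", so the check requires the host to be
// exactly "huggingface.co" or to end in ".huggingface.co".
func isHuggingFaceURL(rawURL string) bool {
	u, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	host := strings.ToLower(u.Hostname())
	return host == "huggingface.co" || strings.HasSuffix(host, ".huggingface.co")
}

func main() {
	for _, u := range []string{
		"https://huggingface.co/org/model/resolve/main/model.gguf",
		"https://cdn-lfs.huggingface.co/some/blob",
		"https://nothuggingface.co/org/model", // must not match
	} {
		fmt.Println(u, "->", isHuggingFaceURL(u))
	}
}
```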
Parth Sareen
805ed4644c server: reduce download concurrency for HuggingFace URLs
Reduces concurrent download parts from 16 to 4 for HuggingFace URLs
to avoid triggering rate limits (HTTP 429 errors).

Adds OLLAMA_HF_CONCURRENCY environment variable for users who want
to customize the concurrency level.

Fixes #13297
2026-01-18 16:38:49 -08:00
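In concrete terms, the download path keeps its usual parallelism for registry blobs and drops to 4 parts only for HuggingFace hosts, with an environment override. A rough sketch of how such a selection could look; the names getNumDownloadParts and OLLAMA_HF_CONCURRENCY come from the commits above, while the signature, parsing, and fallback behavior are assumptions rather than the actual implementation:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

const (
	defaultNumDownloadParts = 16 // concurrency used for regular registry downloads
	hfNumDownloadParts      = 4  // reduced concurrency to stay under HuggingFace rate limits
)

// getNumDownloadParts picks how many parts to download concurrently.
// HuggingFace URLs default to 4 parts to avoid HTTP 429 responses, and
// OLLAMA_HF_CONCURRENCY lets users override that value.
func getNumDownloadParts(isHuggingFace bool) int {
	if !isHuggingFace {
		return defaultNumDownloadParts
	}
	if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return hfNumDownloadParts
}

func main() {
	fmt.Println("registry:", getNumDownloadParts(false))    // 16
	fmt.Println("huggingface:", getNumDownloadParts(true))  // 4 unless OLLAMA_HF_CONCURRENCY is set
}
```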
62 changed files with 1029 additions and 1611 deletions

View File

@@ -190,7 +190,7 @@ if(MLX_ENGINE)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
+PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
-Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
+Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,38 +260,6 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
```shell
go build -tags mlx -o ollama-mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -322,7 +290,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -526,7 +493,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -669,7 +636,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -678,5 +644,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -253,8 +253,6 @@ func main() {
done <- osrv.Run(octx)
}()
upd := &updater.Updater{Store: st}
uiServer := ui.Server{
Token: token,
Restart: func() {
@@ -269,10 +267,6 @@ func main() {
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
Updater: upd,
UpdateAvailableFunc: func() {
UpdateAvailable("")
},
}
srv := &http.Server{
@@ -290,13 +284,8 @@ func main() {
slog.Debug("background desktop server done")
}()
-upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
+updater := &updater.Updater{Store: st}
+updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
if err != nil {
@@ -359,18 +348,6 @@ func startHiddenTasks() {
// CLI triggered app startup use-case
slog.Info("deferring pending update for fast startup")
} else {
// Check if auto-update is enabled before automatically upgrading
st := &store.Store{}
settings, err := st.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
// Still show tray notification so user knows update is ready
UpdateAvailable("")
return
}
if err := updater.DoUpgradeAtStartup(); err != nil {
slog.Info("unable to perform upgrade at startup", "error", err)
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization

View File

@@ -14,7 +14,6 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable;
@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end
@implementation AppDelegate
@@ -41,13 +40,6 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
// Register for system shutdown/restart notification so we can allow termination
[[[NSWorkspace sharedWorkspace] notificationCenter]
addObserver:self
selector:@selector(systemWillPowerOff:)
name:NSWorkspaceWillPowerOffNotification
object:nil];
// if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) {
@@ -286,18 +278,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES];
}
- (void)systemWillPowerOff:(NSNotification *)notification {
// Set flag so applicationShouldTerminate: knows to allow termination.
// The system will call applicationShouldTerminate: after posting this notification.
self.systemShutdownInProgress = YES;
}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
// Allow termination if the system is shutting down or restarting
if (self.systemShutdownInProgress) {
return NSTerminateNow;
}
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel;

View File

@@ -9,12 +9,12 @@ import (
"strings"
"time"
-_ "github.com/mattn/go-sqlite3"
+sqlite3 "github.com/mattn/go-sqlite3"
)
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
-const currentSchemaVersion = 13
+const currentSchemaVersion = 12
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -85,7 +85,6 @@ func (db *database) init() error {
think_enabled BOOLEAN NOT NULL DEFAULT 0,
think_level TEXT NOT NULL DEFAULT '',
remote TEXT NOT NULL DEFAULT '', -- deprecated
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
schema_version INTEGER NOT NULL DEFAULT %d
);
@@ -245,12 +244,6 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v11 to v12: %w", err)
}
version = 12
case 12:
// add auto_update_enabled column to settings table
if err := db.migrateV12ToV13(); err != nil {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -459,21 +452,6 @@ func (db *database) migrateV11ToV12() error {
return nil
}
// migrateV12ToV13 adds the auto_update_enabled column to the settings table
func (db *database) migrateV12ToV13() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add auto_update_enabled column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`
@@ -504,11 +482,19 @@ func (db *database) cleanupOrphanedData() error {
}
func duplicateColumnError(err error) bool {
-return err != nil && strings.Contains(err.Error(), "duplicate column name")
+if sqlite3Err, ok := err.(sqlite3.Error); ok {
+return sqlite3Err.Code == sqlite3.ErrError &&
+strings.Contains(sqlite3Err.Error(), "duplicate column name")
+}
+return false
}
func columnNotExists(err error) bool {
-return err != nil && strings.Contains(err.Error(), "no such column")
+if sqlite3Err, ok := err.(sqlite3.Error); ok {
+return sqlite3Err.Code == sqlite3.ErrError &&
+strings.Contains(sqlite3Err.Error(), "no such column")
+}
+return false
}
func (db *database) getAllChats() ([]Chat, error) {
@@ -1122,9 +1108,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings
err := db.conn.QueryRow(`
-SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
+SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
FROM settings
-`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
+`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err)
}
@@ -1135,8 +1121,8 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error {
_, err := db.conn.Exec(`
UPDATE settings
-SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
-`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
+SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
+`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
if err != nil {
return fmt.Errorf("set settings: %w", err)
}

View File

@@ -169,9 +169,6 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool
// AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool
}
type Store struct {

View File

@@ -413,7 +413,6 @@ export class Settings {
ThinkLevel: string;
SelectedModel: string;
SidebarOpen: boolean;
AutoUpdateEnabled: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
@@ -432,7 +431,6 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
}
}
export class SettingsResponse {
@@ -469,46 +467,6 @@ export class HealthResponse {
this.healthy = source["healthy"];
}
}
export class UpdateInfo {
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.currentVersion = source["currentVersion"];
this.availableVersion = source["availableVersion"];
this.updateAvailable = source["updateAvailable"];
this.updateDownloaded = source["updateDownloaded"];
}
}
export class UpdateCheckResponse {
updateInfo: UpdateInfo;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
if (!a) {
return a;
}
if (Array.isArray(a)) {
return (a as any[]).map(elem => this.convertValues(elem, classs));
} else if ("object" === typeof a) {
if (asMap) {
for (const key of Object.keys(a)) {
a[key] = new classs(a[key]);
}
return a;
}
return new classs(a);
}
return a;
}
}
export class User {
id: string;
email: string;

View File

@@ -414,54 +414,3 @@ export async function fetchHealth(): Promise<boolean> {
return false;
}
}
export async function getCurrentVersion(): Promise<string> {
try {
const response = await fetch(`${API_BASE}/api/version`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.ok) {
const data = await response.json();
return data.version || "Unknown";
}
return "Unknown";
} catch (error) {
console.error("Error fetching version:", error);
return "Unknown";
}
}
export async function checkForUpdate(): Promise<{
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
}> {
const response = await fetch(`${API_BASE}/api/v1/update/check`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
throw new Error("Failed to check for update");
}
const data = await response.json();
return data.updateInfo;
}
export async function installUpdate(): Promise<void> {
const response = await fetch(`${API_BASE}/api/v1/update/install`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
const error = await response.text();
throw new Error(error || "Failed to install update");
}
}

View File

@@ -14,13 +14,12 @@ import {
XMarkIcon,
CogIcon,
ArrowLeftIcon,
ArrowDownTrayIcon,
} from "@heroicons/react/20/solid";
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
import { useUser } from "@/hooks/useUser";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
-import { getSettings, updateSettings, checkForUpdate } from "@/api";
+import { getSettings, updateSettings } from "@/api";
function AnimatedDots() {
return (
@@ -40,12 +39,6 @@ export default function Settings() {
const queryClient = useQueryClient();
const [showSaved, setShowSaved] = useState(false);
const [restartMessage, setRestartMessage] = useState(false);
const [updateInfo, setUpdateInfo] = useState<{
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
} | null>(null);
const {
user,
isAuthenticated,
@@ -83,10 +76,6 @@ export default function Settings() {
useEffect(() => {
refetchUser();
// Check for updates on mount
checkForUpdate()
.then(setUpdateInfo)
.catch((err) => console.error("Error checking for update:", err));
}, []); // eslint-disable-line react-hooks/exhaustive-deps
useEffect(() => {
@@ -355,58 +344,6 @@ export default function Settings() {
{/* Local Configuration */}
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
<div className="space-y-4 p-4">
{/* Auto Update */}
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
<div className="flex-1">
<Label>Auto-download updates</Label>
<Description>
{settings.AutoUpdateEnabled ? (
<>
Automatically downloads updates when available.
<div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
Current version: {updateInfo?.currentVersion || "Loading..."}
</div>
</>
) : (
<>
Manually download updates.
<div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
<div className="space-y-2 text-sm">
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
</div>
{updateInfo?.availableVersion && (
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
</div>
)}
</div>
<a
href="https://ollama.com/download"
target="_blank"
rel="noopener noreferrer"
className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
>
Download new version
</a>
</div>
</>
)}
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={settings.AutoUpdateEnabled}
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
/>
</div>
</div>
</Field>
{/* Expose Ollama */}
<Field>
<div className="flex items-start justify-between gap-4">

View File

@@ -100,17 +100,6 @@ type HealthResponse struct {
Healthy bool `json:"healthy"`
}
type UpdateInfo struct {
CurrentVersion string `json:"currentVersion"`
AvailableVersion string `json:"availableVersion"`
UpdateAvailable bool `json:"updateAvailable"`
UpdateDownloaded bool `json:"updateDownloaded"`
}
type UpdateCheckResponse struct {
UpdateInfo UpdateInfo `json:"updateInfo"`
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`

View File

@@ -28,7 +28,6 @@ import (
"github.com/ollama/ollama/app/tools"
"github.com/ollama/ollama/app/types/not"
"github.com/ollama/ollama/app/ui/responses"
"github.com/ollama/ollama/app/updater"
"github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
@@ -107,18 +106,6 @@ type Server struct {
// Dev is true if the server is running in development mode
Dev bool
// Updater for checking and downloading updates
Updater UpdaterInterface
UpdateAvailableFunc func()
}
// UpdaterInterface defines the methods we need from the updater
type UpdaterInterface interface {
CheckForUpdate(ctx context.Context) (bool, string, error)
InstallAndRestart() error
CancelOngoingDownload()
TriggerImmediateCheck()
}
func (s *Server) log() *slog.Logger {
@@ -297,8 +284,6 @@ func (s *Server) Handler() http.Handler {
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
mux.Handle("POST /api/v1/settings", handle(s.settings))
mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))
// Ollama proxy endpoints
ollamaProxy := s.ollamaProxy()
@@ -1463,24 +1448,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("failed to save settings: %w", err)
}
// Handle auto-update toggle changes
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
if !settings.AutoUpdateEnabled {
// Auto-update disabled: cancel any ongoing download
if s.Updater != nil {
s.Updater.CancelOngoingDownload()
}
} else {
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
s.UpdateAvailableFunc()
} else if s.Updater != nil {
// Trigger the background checker to run immediately
s.Updater.TriggerImmediateCheck()
}
}
}
if old.ContextLength != settings.ContextLength ||
old.Models != settings.Models ||
old.Expose != settings.Expose {
@@ -1557,73 +1524,6 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(response)
}
func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
currentVersion := version.Version
if s.Updater == nil {
return fmt.Errorf("updater not available")
}
updateAvailable, updateVersion, err := s.Updater.CheckForUpdate(r.Context())
if err != nil {
s.log().Warn("failed to check for update", "error", err)
// Don't return error, just log it and continue with no update available
}
response := responses.UpdateCheckResponse{
UpdateInfo: responses.UpdateInfo{
CurrentVersion: currentVersion,
AvailableVersion: updateVersion,
UpdateAvailable: updateAvailable,
UpdateDownloaded: updater.UpdateDownloaded,
},
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(response)
}
func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return fmt.Errorf("method not allowed")
}
if s.Updater == nil {
s.log().Error("install failed: updater not available")
return fmt.Errorf("updater not available")
}
// Check if update is downloaded
if !updater.UpdateDownloaded {
s.log().Error("install failed: no update downloaded")
return fmt.Errorf("no update downloaded")
}
// Send response before restarting
response := map[string]any{
"success": true,
"message": "Installing update and restarting...",
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
// Give the response time to be sent
time.Sleep(500 * time.Millisecond)
// Trigger the upgrade and restart
go func() {
time.Sleep(500 * time.Millisecond)
if err := s.Updater.InstallAndRestart(); err != nil {
s.log().Error("failed to install update", "error", err)
}
}()
return nil
}
func userAgent() string {
buildinfo, _ := debug.ReadBuildInfo()

View File

@@ -19,7 +19,6 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/app/store"
@@ -59,8 +58,7 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
-currentVersion := version.Version
-query.Add("version", currentVersion)
+query.Add("version", version.Version)
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
// The original macOS app used to use the device ID
@@ -133,27 +131,15 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
}
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
// Create a cancellable context for this download
downloadCtx, cancel := context.WithCancel(ctx)
u.cancelDownloadLock.Lock()
u.cancelDownload = cancel
u.cancelDownloadLock.Unlock()
defer func() {
u.cancelDownloadLock.Lock()
u.cancelDownload = nil
u.cancelDownloadLock.Unlock()
cancel()
}()
// Do a head first to check etag info
-req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
+req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
if err != nil {
return err
}
// In case of slow downloads, continue the update check in the background
-bgctx, bgcancel := context.WithCancel(downloadCtx)
-defer bgcancel()
+bgctx, cancel := context.WithCancel(ctx)
+defer cancel()
go func() {
for {
select {
@@ -190,7 +176,6 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(stageFilename)
if err == nil {
slog.Info("update already downloaded", "bundle", stageFilename)
UpdateDownloaded = true
return nil
}
@@ -259,95 +244,34 @@ func cleanupOldDownloads(stageDir string) {
}
type Updater struct {
Store *store.Store
cancelDownload context.CancelFunc
cancelDownloadLock sync.Mutex
checkNow chan struct{}
}
// CancelOngoingDownload cancels any currently running download
func (u *Updater) CancelOngoingDownload() {
u.cancelDownloadLock.Lock()
defer u.cancelDownloadLock.Unlock()
if u.cancelDownload != nil {
slog.Info("cancelling ongoing update download")
u.cancelDownload()
u.cancelDownload = nil
}
}
// TriggerImmediateCheck signals the background checker to check for updates immediately
func (u *Updater) TriggerImmediateCheck() {
if u.checkNow != nil {
u.checkNow <- struct{}{}
}
} }
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
go func() {
// Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay)
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
ticker := time.NewTicker(UpdateCheckInterval)
defer ticker.Stop()
for {
available, resp := u.checkForUpdate(ctx)
if available {
err := u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
} else {
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
}
}
}
select {
case <-ctx.Done():
slog.Debug("stopping background update checker")
return
-case <-u.checkNow:
-// Immediate check triggered
+default:
+time.Sleep(UpdateCheckInterval)
case <-ticker.C:
// Regular interval check
}
// Always check for updates
available, resp := u.checkForUpdate(ctx)
if !available {
continue
}
// Update is available - check if auto-update is enabled for downloading
settings, err := u.Store.Settings()
if err != nil {
slog.Error("failed to load settings", "error", err)
continue
}
if !settings.AutoUpdateEnabled {
// Auto-update disabled - don't download, just log
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
continue
}
// Auto-update is enabled - download
err = u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error("failed to download new release", "error", err)
continue
}
// Download successful - show tray notification (regardless of toggle state)
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)
}
}
}()
}
func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
available, resp := u.checkForUpdate(ctx)
return available, resp.UpdateVersion, nil
}
func (u *Updater) InstallAndRestart() error {
if !UpdateDownloaded {
return fmt.Errorf("no update downloaded")
}
slog.Info("installing update and restarting")
return DoUpgrade(true)
}

View File

@@ -11,7 +11,6 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
@@ -86,17 +85,7 @@ func TestBackgoundChecker(t *testing.T) {
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
-defer updater.Store.Close()
+defer updater.Store.Close() // Ensure database is closed
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
select {
case <-stallTimer.C:
@@ -110,187 +99,3 @@ func TestBackgoundChecker(t *testing.T) {
}
}
}
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
var downloadAttempted atomic.Bool
done := make(chan struct{})
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
UpdateCheckInitialDelay = 5 * time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
downloadAttempted.Store(true)
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
// Ensure auto-update is disabled
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
cb := func(ver string) error {
t.Fatal("callback should not be called when auto-update is disabled")
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait enough time for multiple check cycles
time.Sleep(50 * time.Millisecond)
close(done)
if downloadAttempted.Load() {
t.Fatal("download should not be attempted when auto-update is disabled")
}
}
func TestCancelOngoingDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
downloadStarted := make(chan struct{})
downloadCancelled := make(chan struct{})
ctx := t.Context()
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
if r.Method == http.MethodHead {
w.Header().Set("Content-Length", "1000000")
w.WriteHeader(http.StatusOK)
return
}
// Signal that download has started
close(downloadStarted)
// Wait for cancellation or timeout
select {
case <-r.Context().Done():
close(downloadCancelled)
return
case <-time.After(5 * time.Second):
t.Error("download was not cancelled in time")
}
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
_, resp := updater.checkForUpdate(ctx)
// Start download in goroutine
go func() {
_ = updater.DownloadNewRelease(ctx, resp)
}()
// Wait for download to start
select {
case <-downloadStarted:
case <-time.After(2 * time.Second):
t.Fatal("download did not start in time")
}
// Cancel the download
updater.CancelOngoingDownload()
// Verify cancellation was received
select {
case <-downloadCancelled:
// Success
case <-time.After(2 * time.Second):
t.Fatal("download cancellation was not received by server")
}
}
func TestTriggerImmediateCheck(t *testing.T) {
UpdateStageDir = t.TempDir()
checkCount := atomic.Int32{}
checkDone := make(chan struct{}, 10)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Set a very long interval so only TriggerImmediateCheck causes checks
UpdateCheckInitialDelay = 1 * time.Millisecond
UpdateCheckInterval = 1 * time.Hour
VerifyDownload = func() error {
return nil
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
checkCount.Add(1)
select {
case checkDone <- struct{}{}:
default:
}
// Return no update available
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
cb := func(ver string) error {
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for goroutine to start and pass initial delay
time.Sleep(10 * time.Millisecond)
// With 1 hour interval, no check should have happened yet
initialCount := checkCount.Load()
// Trigger immediate check
updater.TriggerImmediateCheck()
// Wait for the triggered check
select {
case <-checkDone:
case <-time.After(2 * time.Second):
t.Fatal("triggered check did not happen")
}
finalCount := checkCount.Load()
if finalCount <= initialCount {
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
}
}

View File

@@ -369,6 +369,25 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
return nil
}
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
// const ERROR_SUCCESS syscall.Errno = 0
// t.muMenus.RLock()
// menu := uintptr(t.menus[parentId])
// t.muMenus.RUnlock()
// res, _, err := pRemoveMenu.Call(
// menu,
// uintptr(menuItemId),
// MF_BYCOMMAND,
// )
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
// return err
// }
// t.delFromVisibleItems(parentId, menuItemId)
// return nil
// }
func (t *winTray) showMenu() error {
p := point{}
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))

View File

@@ -51,6 +51,7 @@ const (
IMAGE_ICON = 1 // Loads an icon
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
MF_BYCOMMAND = 0x00000000
MFS_DISABLED = 0x00000003
MFT_SEPARATOR = 0x00000800
MFT_STRING = 0x00000000

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
-AltPlaceholder: "Press Enter to send",
+AltPlaceholder: `Use """ to end multi-line input`,
})
if err != nil {
return err

View File

@@ -21,7 +21,6 @@ ollama pull glm-4.7:cloud
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
@@ -248,13 +247,12 @@ curl -X POST http://localhost:11434/v1/messages \
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
-const results = await client.webSearch("what is ollama?");
+const results = await client.webSearch({ query: "what is ollama?" });
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
-const fetchResult = await client.webFetch("https://ollama.com");
+const fetchResult = await client.webFetch({ url: "https://ollama.com" });
console.log(JSON.stringify(fetchResult, null, 2));
```

View File

@@ -111,9 +111,7 @@
"/integrations/zed",
"/integrations/roo-code",
"/integrations/n8n",
-"/integrations/xcode",
-"/integrations/onyx",
-"/integrations/marimo"
+"/integrations/xcode"
]
},
{

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
-By default, Ollama uses a context window size of 4096 tokens.
+By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

9 binary image files removed (not shown); sizes range from 80 KiB to 306 KiB.

View File

@@ -25,7 +25,6 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
@@ -39,7 +38,7 @@ claude --model qwen3-coder
Or run with environment variables inline:
```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com

View File

@@ -1,73 +0,0 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, go to the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base url to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -1,63 +0,0 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Login to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the set-up process select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -1,5 +1,5 @@
---
-title: Linux
+title: "Linux"
---
## Install
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
-If you are upgrading from a prior version, you should remove the old libraries
-with `sudo rm -rf /usr/lib/ollama` first.
+If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+| sudo tar zx -C /usr
```
Start Ollama:
@@ -41,8 +40,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
+| sudo tar zx -C /usr
```
### ARM64 install
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
Download and extract the ARM64-specific package:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
+| sudo tar zx -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -113,11 +112,7 @@ sudo systemctl status ollama
```
<Note>
-While AMD has contributed the `amdgpu` driver upstream to the official linux
-kernel source, the version is older and may not support all ROCm features. We
-recommend you install the latest driver from
-https://www.amd.com/en/support/linux-drivers for best support of your Radeon
-GPU.
+While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
</Note>
## Customizing
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-| sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+| sudo tar zx -C /usr
```
## Installing specific versions

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
-if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
+if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -1464,11 +1464,6 @@ type CompletionRequest struct {
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Seed int64 `json:"seed,omitempty"`
} }
// DoneReason represents the reason why a completion response is done // DoneReason represents the reason why a completion response is done
@@ -1517,11 +1512,6 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested // Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"` Logprobs []Logprob `json:"logprobs,omitempty"`
// Image generation fields
Image []byte `json:"image,omitempty"` // Generated image
Step int `json:"step,omitempty"` // Current generation step
Total int `json:"total,omitempty"` // Total generation steps
} }
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

View File

@@ -8,7 +8,6 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -442,7 +441,6 @@ type ResponsesWriter struct {
stream bool stream bool
responseID string responseID string
itemID string itemID string
request openai.ResponsesRequest
} }
func (w *ResponsesWriter) writeEvent(eventType string, data any) error { func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -480,9 +478,7 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response // Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json") w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request) response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
completedAt := time.Now().Unix()
response.CompletedAt = &completedAt
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response) return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
} }
@@ -527,12 +523,11 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{ w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer}, BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req), converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
model: req.Model, model: req.Model,
stream: streamRequested, stream: streamRequested,
responseID: responseID, responseID: responseID,
itemID: itemID, itemID: itemID,
request: req,
} }
// Set headers based on streaming mode // Set headers based on streaming mode

View File

@@ -630,10 +630,6 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
// decodeImageURL decodes a base64 data URI into raw image bytes. // decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) { func decodeImageURL(url string) (api.ImageData, error) {
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
}
types := []string{"jpeg", "jpg", "png", "webp"} types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64 // Support blank mime type to match /api/chat's behavior of taking just unadorned base64

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand" "math/rand"
"time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -266,9 +265,9 @@ type ResponsesText struct {
type ResponsesTool struct { type ResponsesTool struct {
Type string `json:"type"` // "function" Type string `json:"type"` // "function"
Name string `json:"name"` Name string `json:"name"`
Description *string `json:"description"` // nullable but required Description string `json:"description,omitempty"`
Strict *bool `json:"strict"` // nullable but required Strict bool `json:"strict,omitempty"`
Parameters map[string]any `json:"parameters"` // nullable but required Parameters map[string]any `json:"parameters,omitempty"`
} }
type ResponsesRequest struct { type ResponsesRequest struct {
@@ -476,16 +475,11 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
} }
} }
var description string
if t.Description != nil {
description = *t.Description
}
return api.Tool{ return api.Tool{
Type: t.Type, Type: t.Type,
Function: api.ToolFunction{ Function: api.ToolFunction{
Name: t.Name, Name: t.Name,
Description: description, Description: t.Description,
Parameters: params, Parameters: params,
}, },
}, nil }, nil
@@ -522,60 +516,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
// Response types for the Responses API // Response types for the Responses API
// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
Format ResponsesTextFormat `json:"format"`
}
// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
Effort *string `json:"effort,omitempty"`
Summary *string `json:"summary,omitempty"`
}
// ResponsesError represents an error in the response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"`
}
type ResponsesResponse struct { type ResponsesResponse struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
CompletedAt *int64 `json:"completed_at"` Status string `json:"status"`
Status string `json:"status"` Model string `json:"model"`
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"` Output []ResponsesOutputItem `json:"output"`
Model string `json:"model"` Usage *ResponsesUsage `json:"usage,omitempty"`
PreviousResponseID *string `json:"previous_response_id"` // TODO(drifkin): add `temperature` and `top_p` to the response, but this
Instructions *string `json:"instructions"` // requires additional plumbing to find the effective values since the
Output []ResponsesOutputItem `json:"output"` // defaults can come from the model or the request
Error *ResponsesError `json:"error"`
Tools []ResponsesTool `json:"tools"`
ToolChoice any `json:"tool_choice"`
Truncation string `json:"truncation"`
ParallelToolCalls bool `json:"parallel_tool_calls"`
Text ResponsesTextField `json:"text"`
TopP float64 `json:"top_p"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
TopLogprobs int `json:"top_logprobs"`
Temperature float64 `json:"temperature"`
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
Usage *ResponsesUsage `json:"usage"`
MaxOutputTokens *int `json:"max_output_tokens"`
MaxToolCalls *int `json:"max_tool_calls"`
Store bool `json:"store"`
Background bool `json:"background"`
ServiceTier string `json:"service_tier"`
Metadata map[string]any `json:"metadata"`
SafetyIdentifier *string `json:"safety_identifier"`
PromptCacheKey *string `json:"prompt_cache_key"`
} }
type ResponsesOutputItem struct { type ResponsesOutputItem struct {
@@ -599,39 +550,18 @@ type ResponsesReasoningSummary struct {
} }
type ResponsesOutputContent struct { type ResponsesOutputContent struct {
Type string `json:"type"` // "output_text" Type string `json:"type"` // "output_text"
Text string `json:"text"` Text string `json:"text"`
Annotations []any `json:"annotations"`
Logprobs []any `json:"logprobs"`
}
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens"`
}
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
} }
type ResponsesUsage struct { type ResponsesUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
} }
// derefFloat64 returns the value of a float64 pointer, or a default if nil. // ToResponse converts an api.ChatResponse to a Responses API response
func derefFloat64(p *float64, def float64) float64 { func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
if p != nil {
return *p
}
return def
}
// ToResponse converts an api.ChatResponse to a Responses API response.
// The request is used to echo back request parameters in the response.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
var output []ResponsesOutputItem var output []ResponsesOutputItem
// Add reasoning item if thinking is present // Add reasoning item if thinking is present
@@ -655,7 +585,6 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
output = append(output, ResponsesOutputItem{ output = append(output, ResponsesOutputItem{
ID: fmt.Sprintf("fc_%s_%d", responseID, i), ID: fmt.Sprintf("fc_%s_%d", responseID, i),
Type: "function_call", Type: "function_call",
Status: "completed",
CallID: tc.ID, CallID: tc.ID,
Name: tc.Function.Name, Name: tc.Function.Name,
Arguments: tc.Function.Arguments, Arguments: tc.Function.Arguments,
@@ -669,90 +598,25 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
Role: "assistant", Role: "assistant",
Content: []ResponsesOutputContent{ Content: []ResponsesOutputContent{
{ {
Type: "output_text", Type: "output_text",
Text: chatResponse.Message.Content, Text: chatResponse.Message.Content,
Annotations: []any{},
Logprobs: []any{},
}, },
}, },
}) })
} }
var instructions *string
if request.Instructions != "" {
instructions = &request.Instructions
}
// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
truncation = *request.Truncation
}
tools := request.Tools
if tools == nil {
tools = []ResponsesTool{}
}
text := ResponsesTextField{
Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
text.Format = *request.Text.Format
}
// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
reasoning = &ResponsesReasoningOutput{}
if request.Reasoning.Effort != "" {
reasoning.Effort = &request.Reasoning.Effort
}
if request.Reasoning.Summary != "" {
reasoning.Summary = &request.Reasoning.Summary
}
}
return ResponsesResponse{ return ResponsesResponse{
ID: responseID, ID: responseID,
Object: "response", Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(), CreatedAt: chatResponse.CreatedAt.Unix(),
CompletedAt: nil, // Set by middleware when writing final response Status: "completed",
Status: "completed", Model: model,
IncompleteDetails: nil, // Only populated if response incomplete Output: output,
Model: model,
PreviousResponseID: nil, // Not supported
Instructions: instructions,
Output: output,
Error: nil, // Only populated on failure
Tools: tools,
ToolChoice: "auto", // Default value
Truncation: truncation,
ParallelToolCalls: true, // Default value
Text: text,
TopP: derefFloat64(request.TopP, 1.0),
PresencePenalty: 0, // Default value
FrequencyPenalty: 0, // Default value
TopLogprobs: 0, // Default value
Temperature: derefFloat64(request.Temperature, 1.0),
Reasoning: reasoning,
Usage: &ResponsesUsage{ Usage: &ResponsesUsage{
InputTokens: chatResponse.PromptEvalCount, InputTokens: chatResponse.PromptEvalCount,
OutputTokens: chatResponse.EvalCount, OutputTokens: chatResponse.EvalCount,
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount, TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
// TODO(drifkin): wire through the actual values
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
// TODO(drifkin): wire through the actual values
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
}, },
MaxOutputTokens: request.MaxOutputTokens,
MaxToolCalls: nil, // Not supported
Store: false, // We don't store responses
Background: request.Background,
ServiceTier: "default", // Default value
Metadata: map[string]any{},
SafetyIdentifier: nil, // Not supported
PromptCacheKey: nil, // Not supported
} }
} }
@@ -772,7 +636,6 @@ type ResponsesStreamConverter struct {
responseID string responseID string
itemID string itemID string
model string model string
request ResponsesRequest
// State tracking (mutated across Process calls) // State tracking (mutated across Process calls)
firstWrite bool firstWrite bool
@@ -805,12 +668,11 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
} }
// NewResponsesStreamConverter creates a new converter with the given configuration. // NewResponsesStreamConverter creates a new converter with the given configuration.
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter { func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
return &ResponsesStreamConverter{ return &ResponsesStreamConverter{
responseID: responseID, responseID: responseID,
itemID: itemID, itemID: itemID,
model: model, model: model,
request: request,
firstWrite: true, firstWrite: true,
} }
} }
@@ -855,120 +717,25 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
return events return events
} }
// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
var instructions any = nil
if c.request.Instructions != "" {
instructions = c.request.Instructions
}
truncation := "disabled"
if c.request.Truncation != nil {
truncation = *c.request.Truncation
}
var tools []any
if c.request.Tools != nil {
for _, t := range c.request.Tools {
tools = append(tools, map[string]any{
"type": t.Type,
"name": t.Name,
"description": t.Description,
"strict": t.Strict,
"parameters": t.Parameters,
})
}
}
if tools == nil {
tools = []any{}
}
textFormat := map[string]any{"type": "text"}
if c.request.Text != nil && c.request.Text.Format != nil {
textFormat = map[string]any{
"type": c.request.Text.Format.Type,
}
if c.request.Text.Format.Name != "" {
textFormat["name"] = c.request.Text.Format.Name
}
if c.request.Text.Format.Schema != nil {
textFormat["schema"] = c.request.Text.Format.Schema
}
if c.request.Text.Format.Strict != nil {
textFormat["strict"] = *c.request.Text.Format.Strict
}
}
var reasoning any = nil
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
r := map[string]any{}
if c.request.Reasoning.Effort != "" {
r["effort"] = c.request.Reasoning.Effort
} else {
r["effort"] = nil
}
if c.request.Reasoning.Summary != "" {
r["summary"] = c.request.Reasoning.Summary
} else {
r["summary"] = nil
}
reasoning = r
}
// Build top_p and temperature with defaults
topP := 1.0
if c.request.TopP != nil {
topP = *c.request.TopP
}
temperature := 1.0
if c.request.Temperature != nil {
temperature = *c.request.Temperature
}
return map[string]any{
"id": c.responseID,
"object": "response",
"created_at": time.Now().Unix(),
"completed_at": nil,
"status": status,
"incomplete_details": nil,
"model": c.model,
"previous_response_id": nil,
"instructions": instructions,
"output": output,
"error": nil,
"tools": tools,
"tool_choice": "auto",
"truncation": truncation,
"parallel_tool_calls": true,
"text": map[string]any{"format": textFormat},
"top_p": topP,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_logprobs": 0,
"temperature": temperature,
"reasoning": reasoning,
"usage": usage,
"max_output_tokens": c.request.MaxOutputTokens,
"max_tool_calls": nil,
"store": false,
"background": c.request.Background,
"service_tier": "default",
"metadata": map[string]any{},
"safety_identifier": nil,
"prompt_cache_key": nil,
}
}
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent { func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
return c.newEvent("response.created", map[string]any{ return c.newEvent("response.created", map[string]any{
"response": c.buildResponseObject("in_progress", []any{}, nil), "response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
}) })
} }
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent { func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
return c.newEvent("response.in_progress", map[string]any{ return c.newEvent("response.in_progress", map[string]any{
"response": c.buildResponseObject("in_progress", []any{}, nil), "response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
}) })
} }
@@ -995,10 +762,9 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
// Emit delta // Emit delta
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{ events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
"item_id": c.reasoningItemID, "item_id": c.reasoningItemID,
"output_index": c.outputIndex, "output_index": c.outputIndex,
"summary_index": 0, "delta": thinking,
"delta": thinking,
})) }))
// TODO(drifkin): consider adding // TODO(drifkin): consider adding
@@ -1017,10 +783,9 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
events := []ResponsesStreamEvent{ events := []ResponsesStreamEvent{
c.newEvent("response.reasoning_summary_text.done", map[string]any{ c.newEvent("response.reasoning_summary_text.done", map[string]any{
"item_id": c.reasoningItemID, "item_id": c.reasoningItemID,
"output_index": c.outputIndex, "output_index": c.outputIndex,
"summary_index": 0, "text": c.accumulatedThinking,
"text": c.accumulatedThinking,
}), }),
c.newEvent("response.output_item.done", map[string]any{ c.newEvent("response.output_item.done", map[string]any{
"output_index": c.outputIndex, "output_index": c.outputIndex,
@@ -1133,10 +898,8 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": c.contentIndex, "content_index": c.contentIndex,
"part": map[string]any{ "part": map[string]any{
"type": "output_text", "type": "output_text",
"text": "", "text": "",
"annotations": []any{},
"logprobs": []any{},
}, },
})) }))
} }
@@ -1150,7 +913,6 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"delta": content, "delta": content,
"logprobs": []any{},
})) }))
return events return events
@@ -1182,10 +944,8 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
"status": "completed", "status": "completed",
"role": "assistant", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}}, }},
}) })
} }
@@ -1207,7 +967,6 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"text": c.accumulatedText, "text": c.accumulatedText,
"logprobs": []any{},
})) }))
// response.content_part.done // response.content_part.done
@@ -1216,10 +975,8 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"part": map[string]any{ "part": map[string]any{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}, },
})) }))
@@ -1232,31 +989,26 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"status": "completed", "status": "completed",
"role": "assistant", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}}, }},
}, },
})) }))
} }
// response.completed // response.completed
usage := map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
"output_tokens_details": map[string]any{
"reasoning_tokens": 0,
},
}
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
response["completed_at"] = time.Now().Unix()
events = append(events, c.newEvent("response.completed", map[string]any{ events = append(events, c.newEvent("response.completed", map[string]any{
"response": response, "response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "completed",
"output": c.buildFinalOutput(),
"usage": map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
},
},
})) }))
return events return events

View File

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
} }
func TestResponsesStreamConverter_TextOnly(t *testing.T) { func TestResponsesStreamConverter_TextOnly(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk with content // First chunk with content
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
} }
func TestResponsesStreamConverter_ToolCalls(t *testing.T) { func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{ Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
} }
func TestResponsesStreamConverter_Reasoning(t *testing.T) { func TestResponsesStreamConverter_Reasoning(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk with thinking // First chunk with thinking
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42", Content: "The answer is 42",
}, },
Done: true, Done: true,
}, ResponsesRequest{}) })
// Should have 2 output items: reasoning + message // Should have 2 output items: reasoning + message
if len(response.Output) != 2 { if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) { func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages // Verify that response.output_item.done includes content field for messages
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk // First chunk
converter.Process(api.ChatResponse{ converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) { func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array // Verify that response.completed includes the output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// Process some content // Process some content
converter.Process(api.ChatResponse{ converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) { func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array // Verify that response.created includes an empty output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"}, Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) { func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers // Verify that events include incrementing sequence numbers
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"}, Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) { func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field // Verify that function call items include status field
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{}) converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{ Message: api.Message{

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
) )
type Prompt struct { type Prompt struct {
@@ -37,11 +36,10 @@ type Terminal struct {
} }
type Instance struct { type Instance struct {
Prompt *Prompt Prompt *Prompt
Terminal *Terminal Terminal *Terminal
History *History History *History
Pasting bool Pasting bool
pastedLines []string
} }
func New(prompt Prompt) (*Instance, error) { func New(prompt Prompt) (*Instance, error) {
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
case CharEsc: case CharEsc:
esc = true esc = true
case CharInterrupt: case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt return "", ErrInterrupt
case CharPrev: case CharPrev:
i.historyPrev(buf, &currentLineBuf) i.historyPrev(buf, &currentLineBuf)
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
case CharForward: case CharForward:
buf.MoveRight() buf.MoveRight()
case CharBackspace, CharCtrlH: case CharBackspace, CharCtrlH:
if buf.IsEmpty() && len(i.pastedLines) > 0 { buf.Remove()
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
case CharTab: case CharTab:
// todo: convert back to real tabs // todo: convert back to real tabs
for range 8 { for range 8 {
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ: case CharCtrlZ:
fd := os.Stdin.Fd() fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios) return handleCharCtrlZ(fd, i.Terminal.termios)
case CharCtrlJ: case CharEnter, CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
output := buf.String() output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" { if output != "" {
i.History.Add(output) i.History.Add(output)
} }
buf.MoveToEnd() buf.MoveToEnd()
fmt.Println() fmt.Println()
i.Prompt.UseAlt = false
return output, nil return output, nil
default: default:

View File

@@ -179,7 +179,7 @@ _build_macapp() {
fi fi
rm -f dist/Ollama-darwin.zip rm -f dist/Ollama-darwin.zip
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple # Notarize and Staple
@@ -187,7 +187,7 @@ _build_macapp() {
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID" $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app $(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
rm -f dist/Ollama.dmg rm -f dist/Ollama.dmg

View File

@@ -50,17 +50,12 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil return redirectURL, nil
} }
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) { func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
redirectURL, err := challenge.URL() redirectURL, err := challenge.URL()
if err != nil { if err != nil {
return "", err return "", err
} }
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
if redirectURL.Host != originalHost {
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
}
sha256sum := sha256.Sum256(nil) sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))) data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))

View File

@@ -1,113 +0,0 @@
package server
import (
"context"
"strings"
"testing"
"time"
)
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}
for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}
func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}
for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}
func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}
u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}
if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}
q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}
// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}
func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}

View File

@@ -95,11 +95,48 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
} }
const ( const (
numDownloadParts = 16 // numDownloadParts is the default number of concurrent download parts for standard downloads
numDownloadParts = 16
// numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
// downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
numHFDownloadParts = 4
minDownloadPartSize int64 = 100 * format.MegaByte minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte maxDownloadPartSize int64 = 1000 * format.MegaByte
) )
// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
// This includes:
// - huggingface.co (main domain)
// - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
// - hf.co (shortlink domain)
// - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
func isHuggingFaceURL(u *url.URL) bool {
if u == nil {
return false
}
host := strings.ToLower(u.Hostname())
return host == "huggingface.co" ||
strings.HasSuffix(host, ".huggingface.co") ||
host == "hf.co" ||
strings.HasSuffix(host, ".hf.co")
}
// getNumDownloadParts returns the number of concurrent download parts to use
// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
func getNumDownloadParts(u *url.URL) int {
if isHuggingFaceURL(u) {
if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return numHFDownloadParts
}
return numDownloadParts
}
func (p *blobDownloadPart) Name() string { func (p *blobDownloadPart) Name() string {
return strings.Join([]string{ return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N), p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -271,7 +308,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
} }
g, inner := errgroup.WithContext(ctx) g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts) concurrency := getNumDownloadParts(directURL)
if concurrency != numDownloadParts {
slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
}
g.SetLimit(concurrency)
for i := range b.Parts { for i := range b.Parts {
part := b.Parts[i] part := b.Parts[i]
if part.Completed.Load() == part.Size { if part.Completed.Load() == part.Size {
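To illustrate the override path added in this hunk: the environment variable is read by the server process when it schedules download parts, so it needs to be set where `ollama serve` runs (a hypothetical invocation sketch, not part of this change):

```shell
# Cap Hugging Face downloads at 8 concurrent parts for this server instance.
# Non-HuggingFace registries keep the standard concurrency of 16.
OLLAMA_HF_CONCURRENCY=8 ollama serve
```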

194
server/download_test.go Normal file
View File

@@ -0,0 +1,194 @@
package server
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsHuggingFaceURL(t *testing.T) {
tests := []struct {
name string
url string
expected bool
}{
{
name: "nil url",
url: "",
expected: false,
},
{
name: "huggingface.co main domain",
url: "https://huggingface.co/some/model",
expected: true,
},
{
name: "cdn-lfs.huggingface.co subdomain",
url: "https://cdn-lfs.huggingface.co/repos/abc/123",
expected: true,
},
{
name: "cdn-lfs3.hf.co CDN domain",
url: "https://cdn-lfs3.hf.co/repos/abc/123",
expected: true,
},
{
name: "hf.co shortlink domain",
url: "https://hf.co/model",
expected: true,
},
{
name: "uppercase HuggingFace domain",
url: "https://HUGGINGFACE.CO/model",
expected: true,
},
{
name: "mixed case HF domain",
url: "https://Cdn-Lfs.HF.Co/repos",
expected: true,
},
{
name: "ollama registry",
url: "https://registry.ollama.ai/v2/library/llama3",
expected: false,
},
{
name: "github.com",
url: "https://github.com/ollama/ollama",
expected: false,
},
{
name: "fake huggingface domain",
url: "https://nothuggingface.co/model",
expected: false,
},
{
name: "fake hf domain",
url: "https://nothf.co/model",
expected: false,
},
{
name: "huggingface in path not host",
url: "https://example.com/huggingface.co/model",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var u *url.URL
if tc.url != "" {
var err error
u, err = url.Parse(tc.url)
if err != nil {
t.Fatalf("failed to parse URL: %v", err)
}
}
got := isHuggingFaceURL(u)
assert.Equal(t, tc.expected, got)
})
}
}
func TestGetNumDownloadParts(t *testing.T) {
tests := []struct {
name string
url string
envValue string
expected int
description string
}{
{
name: "nil url returns default",
url: "",
envValue: "",
expected: numDownloadParts,
description: "nil URL should return standard concurrency",
},
{
name: "ollama registry returns default",
url: "https://registry.ollama.ai/v2/library/llama3",
envValue: "",
expected: numDownloadParts,
description: "Ollama registry should use standard concurrency",
},
{
name: "huggingface returns reduced default",
url: "https://huggingface.co/model/repo",
envValue: "",
expected: numHFDownloadParts,
description: "HuggingFace should use reduced concurrency",
},
{
name: "hf.co CDN returns reduced default",
url: "https://cdn-lfs3.hf.co/repos/abc/123",
envValue: "",
expected: numHFDownloadParts,
description: "HuggingFace CDN should use reduced concurrency",
},
{
name: "huggingface with env override",
url: "https://huggingface.co/model/repo",
envValue: "2",
expected: 2,
description: "OLLAMA_HF_CONCURRENCY should override default",
},
{
name: "huggingface with higher env override",
url: "https://huggingface.co/model/repo",
envValue: "8",
expected: 8,
description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
},
{
name: "huggingface with invalid env (non-numeric)",
url: "https://huggingface.co/model/repo",
envValue: "invalid",
expected: numHFDownloadParts,
description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
},
{
name: "huggingface with invalid env (zero)",
url: "https://huggingface.co/model/repo",
envValue: "0",
expected: numHFDownloadParts,
description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
},
{
name: "huggingface with invalid env (negative)",
url: "https://huggingface.co/model/repo",
envValue: "-1",
expected: numHFDownloadParts,
description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
},
{
name: "non-huggingface ignores env",
url: "https://registry.ollama.ai/v2/library/llama3",
envValue: "2",
expected: numDownloadParts,
description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Set or clear the environment variable
if tc.envValue != "" {
t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
}
var u *url.URL
if tc.url != "" {
var err error
u, err = url.Parse(tc.url)
if err != nil {
t.Fatalf("failed to parse URL: %v", err)
}
}
got := getNumDownloadParts(u)
assert.Equal(t, tc.expected, got, tc.description)
})
}
}

View File

@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm, Realm: challenge.Realm,
Service: challenge.Service, Service: challenge.Service,
Scope: challenge.Scope, Scope: challenge.Scope,
}, base.Host) })
} }
if err := transfer.Download(ctx, transfer.DownloadOptions{ if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm, Realm: challenge.Realm,
Service: challenge.Service, Service: challenge.Service,
Scope: challenge.Scope, Scope: challenge.Scope,
}, base.Host) })
} }
return transfer.Upload(ctx, transfer.UploadOptions{ return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
// Handle authentication error with one retry // Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host) token, err := getAuthorizationToken(ctx, challenge)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -51,6 +51,7 @@ import (
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
) )
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -163,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil return runner.llama, model, &opts, nil
} }
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) { func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey() pubKey, err := auth.GetPublicKey()
if err != nil { if err != nil {
@@ -190,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model) name := model.ParseName(req.Model)
if !name.IsValid() { if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with // Ideally this is "invalid model name" but we're keeping with
@@ -1557,12 +1587,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler) r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Experimental OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", s.handleImageGeneration)
// Inference (Anthropic compatibility) // Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler) r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil { if rc != nil {
// wrap old with new // wrap old with new
rs := &registry.Local{ rs := &registry.Local{
@@ -1880,62 +1911,6 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b)) return "call_" + strings.ToLower(string(b))
} }
func (s *Server) handleImageGeneration(c *gin.Context) {
var req struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
m, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Parse size (e.g., "1024x768") into width and height
width, height := int32(1024), int32(1024)
if req.Size != "" {
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
return
}
}
var image []byte
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: width,
Height: height,
}, func(resp llm.CompletionResponse) {
if len(resp.Image) > 0 {
image = resp.Image
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"created": time.Now().Unix(),
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
})
}
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"os" "os"
"slices"
"testing" "testing"
"time" "time"
@@ -16,6 +17,7 @@ import (
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@@ -805,8 +807,32 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false } func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil } func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model // TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted when idle. // loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) { func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond) ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done() defer done()
@@ -838,59 +864,3 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner) require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath) require.Equal(t, "/fake/image/model", runner.modelPath)
} }
// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()
// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()
// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)
// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}

View File

@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
w.Rollback() w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host) token, err := getAuthorizationToken(ctx, challenge)
if err != nil { if err != nil {
return err return err
} }

50
x/README.md Normal file
View File

@@ -0,0 +1,50 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx).
Support is currently limited to macOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA and other GPU vendors soon.
### Building ollama-mlx
The `ollama-mlx` binary is a separate build of Ollama with MLX support, which enables experimental features such as image generation.
#### macOS (Apple Silicon and Intel)
```bash
# Build MLX backend libraries
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
go build -tags mlx -o ollama-mlx .
```
#### Linux (CUDA)
On Linux, use the "MLX CUDA 13" or "MLX CUDA 12" preset to enable CUDA with the default Ollama NVIDIA GPU architectures:
```bash
# Build MLX backend libraries with CUDA support
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
go build -tags mlx -o ollama-mlx .
```
#### Using build scripts
The build scripts automatically create the `ollama-mlx` binary:
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
## Image Generation
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
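As a rough illustration of the endpoint this enables, a request against the OpenAI-compatible image generation route added in `x/imagegen/api/handler.go` could look like the following (a sketch; the model name is a placeholder and the default port 11434 is assumed):

```shell
curl http://localhost:11434/v1/images/generations -d '{
  "model": "<image-model>",
  "prompt": "a watercolor painting of a lighthouse at dusk",
  "size": "1024x1024",
  "response_format": "b64_json"
}'
```

If the request succeeds, the response is expected to carry base64-encoded image data, matching the `b64_json` default applied by the handler.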

View File

@@ -25,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools" "github.com/ollama/ollama/x/tools"
) )
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants // Tool output capping constants
const ( const (
// localModelTokenLimit is the token limit for local models (smaller context). // localModelTokenLimit is the token limit for local models (smaller context).
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ", Prompt: ">>> ",
AltPrompt: "... ", AltPrompt: "... ",
Placeholder: "Send a message (/? for help)", Placeholder: "Send a message (/? for help)",
AltPlaceholder: "Press Enter to send", AltPlaceholder: `Use """ to end multi-line input`,
}) })
if err != nil { if err != nil {
return err return err
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder var sb strings.Builder
var format string var format string
var system string var system string
var multiline MultilineState = MultilineNone
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
} }
scanner.Prompt.UseAlt = false scanner.Prompt.UseAlt = false
sb.Reset() sb.Reset()
multiline = MultilineNone
continue continue
case err != nil: case err != nil:
return err return err
} }
switch { switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil return nil
case strings.HasPrefix(line, "/clear"): case strings.HasPrefix(line, "/clear"):
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]] options[args[2]] = fp[args[2]]
case "system": case "system":
if len(args) < 3 { if len(args) < 3 {
fmt.Println("Usage: /set system <message>") fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue continue
} }
system = strings.Join(args[2:], " ") multiline = MultilineSystem
newMessage := api.Message{Role: "system", Content: system}
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
-if sb.Len() > 0 {
+if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)

231
x/imagegen/api/handler.go Normal file
View File

@@ -0,0 +1,231 @@
package api
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
return resp
}
if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
resp.Data[0].B64JSON = imageBase64
}
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// Log error but don't block - channel is already being consumed
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}

31
x/imagegen/api/types.go Normal file
View File

@@ -0,0 +1,31 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}

View File

@@ -7,6 +7,7 @@ package imagegen
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
@@ -38,17 +39,77 @@ func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
-Steps: 0, // 0 means model default
+Steps: 9,
Seed: 0, // 0 means random
}
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
-cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
+cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
@@ -91,18 +152,23 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
}
// generateImageWithOptions generates an image with the given options.
-// Note: opts are currently unused as the native API doesn't support size parameters.
-// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
-func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
+func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
-// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
+Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive

View File

@@ -12,7 +12,6 @@ import (
"path/filepath" "path/filepath"
"runtime/pprof" "runtime/pprof"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3" "github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss" "github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -49,7 +48,7 @@ func main() {
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
-steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
+steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")

View File

@@ -175,63 +175,3 @@ func (m *ModelManifest) HasTensorLayers() bool {
}
return false
}
-// ModelInfo contains metadata about an image generation model.
-type ModelInfo struct {
-Architecture string
-ParameterCount int64
-Quantization string
-}
-// GetModelInfo returns metadata about an image generation model.
-func GetModelInfo(modelName string) (*ModelInfo, error) {
-manifest, err := LoadManifest(modelName)
-if err != nil {
-return nil, fmt.Errorf("failed to load manifest: %w", err)
-}
-info := &ModelInfo{}
-// Read model_index.json for architecture, parameter count, and quantization
-if data, err := manifest.ReadConfig("model_index.json"); err == nil {
-var index struct {
-Architecture string `json:"architecture"`
-ParameterCount int64 `json:"parameter_count"`
-Quantization string `json:"quantization"`
-}
-if json.Unmarshal(data, &index) == nil {
-info.Architecture = index.Architecture
-info.ParameterCount = index.ParameterCount
-info.Quantization = index.Quantization
-}
-}
-// Fallback: detect quantization from tensor names if not in config
-if info.Quantization == "" {
-for _, layer := range manifest.Manifest.Layers {
-if strings.HasSuffix(layer.Name, ".weight_scale") {
-info.Quantization = "FP8"
-break
-}
-}
-if info.Quantization == "" {
-info.Quantization = "BF16"
-}
-}
-// Fallback: estimate parameter count if not in config
-if info.ParameterCount == 0 {
-var totalSize int64
-for _, layer := range manifest.Manifest.Layers {
-if layer.MediaType == "application/vnd.ollama.image.tensor" {
-if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
-totalSize += layer.Size
-}
-}
-}
-// Assume BF16 (2 bytes/param) as rough estimate
-info.ParameterCount = totalSize / 2
-}
-return info, nil
-}

View File

@@ -95,3 +95,8 @@ func EstimateVRAM(modelName string) uint64 {
}
return 21 * GB
}
// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
return ResolveModelName(modelName) != ""
}

View File

@@ -94,6 +94,13 @@ func TestEstimateVRAMDefault(t *testing.T) {
}
}
func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")

View File

@@ -9,7 +9,6 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache" "github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -173,7 +172,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
-cfg.Steps = 50
+cfg.Steps = 30
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0

View File

@@ -194,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
cfg.Height = 1024
}
if cfg.Steps <= 0 {
-cfg.Steps = 9 // Z-Image turbo default
+cfg.Steps = 9 // Turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0

View File

@@ -136,8 +136,16 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
defer s.mu.Unlock()
-// Model applies its own defaults for width/height/steps
-// Only seed needs to be set here if not provided
+// Apply defaults
+if req.Width <= 0 {
+req.Width = 1024
+}
+if req.Height <= 0 {
+req.Height = 1024
+}
+if req.Steps <= 0 {
+req.Steps = 9
+}
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}

View File

@@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -26,11 +25,6 @@ import (
)
// Server wraps an image generation subprocess to implement llm.LlamaServer.
-//
-// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
-// like any other model. The plan is to eventually bring this into the llm/ package
-// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
-// separate allows for independent iteration on image generation support.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
@@ -43,6 +37,22 @@ type Server struct {
lastErrLock sync.Mutex
}
// completionRequest is sent to the subprocess
type completionRequest struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Done bool `json:"done"`
}
// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start
@@ -129,6 +139,7 @@ func NewServer(modelName string) (*Server, error) {
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
// Capture last error line for better error reporting
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
@@ -160,6 +171,7 @@ func (s *Server) ModelPath() string {
return s.modelName
}
// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -192,16 +204,20 @@ func (s *Server) waitUntilRunning() error {
for {
select {
case err := <-s.done:
-// Include recent stderr lines for better error context
-errMsg := s.getLastErr()
-if errMsg != "" {
-return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
+// Include last stderr line for better error context
+s.lastErrLock.Lock()
+lastErr := s.lastErr
+s.lastErrLock.Unlock()
+if lastErr != "" {
+return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout:
-errMsg := s.getLastErr()
-if errMsg != "" {
-return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
+s.lastErrLock.Lock()
+lastErr := s.lastErr
+s.lastErrLock.Unlock()
+if lastErr != "" {
+return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
}
return errors.New("timeout waiting for image runner to start")
case <-ticker.C:
@@ -213,39 +229,44 @@ func (s *Server) waitUntilRunning() error {
}
}
-// getLastErr returns the last stderr line.
-func (s *Server) getLastErr() string {
-s.lastErrLock.Lock()
-defer s.lastErrLock.Unlock()
-return s.lastErr
-}
-
-func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
+// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
+func (s *Server) WaitUntilRunning(ctx context.Context) error {
+return nil
+}

// Completion generates an image from the prompt via the subprocess.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
-seed := req.Seed
-if seed == 0 {
-seed = time.Now().UnixNano()
-}
-// Build request for subprocess
-creq := struct {
-Prompt string `json:"prompt"`
-Width int32 `json:"width,omitempty"`
-Height int32 `json:"height,omitempty"`
-Seed int64 `json:"seed,omitempty"`
-}{
-Prompt: req.Prompt,
-Width: req.Width,
-Height: req.Height,
-Seed: seed,
-}
+// Build request
+creq := completionRequest{
+Prompt: req.Prompt,
+Width: 1024,
+Height: 1024,
+Steps: 9,
+Seed: time.Now().UnixNano(),
+}
if req.Options != nil {
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
creq.Width = int32(req.Options.NumCtx)
}
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
creq.Height = int32(req.Options.NumGPU)
}
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
creq.Steps = req.Options.NumPredict
}
if req.Options.Seed > 0 {
creq.Seed = int64(req.Options.Seed)
}
}
// Encode request body
body, err := json.Marshal(creq)
if err != nil {
return err
}
// Send request to subprocess
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
@@ -260,40 +281,30 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
-return fmt.Errorf("request failed: %d", resp.StatusCode)
+return fmt.Errorf("completion request failed: %d", resp.StatusCode)
}
-// Stream responses - use large buffer for base64 image data
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
-// Parse subprocess response (has singular "image" field)
-var raw struct {
-Image string `json:"image,omitempty"`
-Content string `json:"content,omitempty"`
-Done bool `json:"done"`
-Step int `json:"step,omitempty"`
-Total int `json:"total,omitempty"`
-}
-if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
+var cresp completionResponse
+if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
continue
}
-// Convert to llm.CompletionResponse
-cresp := llm.CompletionResponse{
-Content: raw.Content,
-Done: raw.Done,
-Step: raw.Step,
-Total: raw.Total,
-}
-if raw.Image != "" {
-if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
-cresp.Image = data
-}
-}
-fn(cresp)
+content := cresp.Content
+// If this is the final response with an image, encode it in the content
+if cresp.Done && cresp.Image != "" {
+content = "IMAGE_BASE64:" + cresp.Image
+}
+fn(llm.CompletionResponse{
+Content: content,
+Done: cresp.Done,
+})
if cresp.Done {
-return nil
+break
}
}
@@ -335,18 +346,22 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
-return nil, 0, errors.New("not supported")
+return nil, 0, errors.New("embedding not supported for image generation models")
}
// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
-return nil, errors.New("not supported")
+return nil, errors.New("tokenize not supported for image generation models")
}
// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
-return "", errors.New("not supported")
+return "", errors.New("detokenize not supported for image generation models")
}
// Pid returns the subprocess PID.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -356,9 +371,17 @@ func (s *Server) Pid() int {
return -1
}
-func (s *Server) GetPort() int { return s.port }
-func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
+// GetPort returns the subprocess port.
+func (s *Server) GetPort() int {
+return s.port
+}
+
+// GetDeviceInfos returns nil since we don't track GPU info.
+func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+return nil
+}
// HasExited returns true if the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done: