Compare commits
32 Commits
parth/decr
...
hoyyeva/up
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33c974f331 | ||
|
|
4673dfaeaa | ||
|
|
50ad3f740e | ||
|
|
9c30267cca | ||
|
|
6338186fe2 | ||
|
|
7d56041abf | ||
|
|
211c52abaa | ||
|
|
833dabdcde | ||
|
|
cf7cd4a12e | ||
|
|
cd89530ead | ||
|
|
1472cac5c9 | ||
|
|
8e61ff77d9 | ||
|
|
11b7ff2fde | ||
|
|
011b7de573 | ||
|
|
3da690abac | ||
|
|
c167124c02 | ||
|
|
e68a23e3c5 | ||
|
|
c23d5095de | ||
|
|
7601f0e93e | ||
|
|
aad3f03890 | ||
|
|
55d0b6e8b9 | ||
|
|
38eac40d56 | ||
|
|
80f3f1bc25 | ||
|
|
b1a0db547b | ||
|
|
75d7b5f926 | ||
|
|
349d814814 | ||
|
|
c8743031e0 | ||
|
|
4adb9cf4bb | ||
|
|
74f475e735 | ||
|
|
875cecba74 | ||
|
|
7d411a4686 | ||
|
|
02a2401596 |
@@ -190,7 +190,7 @@ if(MLX_ENGINE)
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
|
||||
43
README.md
@@ -48,7 +48,7 @@ ollama run gemma3
|
||||
|
||||
## Model library
|
||||
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
|
||||
|
||||
Here are some example models that can be downloaded:
|
||||
|
||||
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
|
||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
|
||||
> [!NOTE]
|
||||
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
|
||||
./ollama run llama3.2
|
||||
```
|
||||
|
||||
## Building with MLX (experimental)
|
||||
|
||||
First build the MLX libraries:
|
||||
|
||||
```shell
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
|
||||
|
||||
```shell
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
Finally, start the server:
|
||||
|
||||
```
|
||||
./ollama serve
|
||||
```
|
||||
|
||||
### Building MLX with CUDA
|
||||
|
||||
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
|
||||
|
||||
```shell
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
cmake --build --preset 'MLX CUDA 13' --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
## REST API
|
||||
|
||||
Ollama has a REST API for running and managing models.
|
||||
@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Onyx](https://github.com/onyx-dot-app/onyx)
|
||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
|
||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||
@@ -421,7 +454,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
### Database
|
||||
|
||||
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
|
||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
||||
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
|
||||
@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
|
||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||
@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
||||
|
||||
### Security
|
||||
|
||||
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
|
||||
|
||||
@@ -253,6 +253,8 @@ func main() {
|
||||
done <- osrv.Run(octx)
|
||||
}()
|
||||
|
||||
upd := &updater.Updater{Store: st}
|
||||
|
||||
uiServer := ui.Server{
|
||||
Token: token,
|
||||
Restart: func() {
|
||||
@@ -267,6 +269,10 @@ func main() {
|
||||
ToolRegistry: toolRegistry,
|
||||
Dev: devMode,
|
||||
Logger: slog.Default(),
|
||||
Updater: upd,
|
||||
UpdateAvailableFunc: func() {
|
||||
UpdateAvailable("")
|
||||
},
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
@@ -284,8 +290,13 @@ func main() {
|
||||
slog.Debug("background desktop server done")
|
||||
}()
|
||||
|
||||
updater := &updater.Updater{Store: st}
|
||||
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||
|
||||
// Check for pending updates on startup (show tray notification if update is ready)
|
||||
if updater.IsUpdatePending() {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
|
||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||
if err != nil {
|
||||
@@ -348,6 +359,18 @@ func startHiddenTasks() {
|
||||
// CLI triggered app startup use-case
|
||||
slog.Info("deferring pending update for fast startup")
|
||||
} else {
|
||||
// Check if auto-update is enabled before automatically upgrading
|
||||
st := &store.Store{}
|
||||
settings, err := st.Settings()
|
||||
if err != nil {
|
||||
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||
} else if !settings.AutoUpdateEnabled {
|
||||
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
|
||||
// Still show tray notification so user knows update is ready
|
||||
UpdateAvailable("")
|
||||
return
|
||||
}
|
||||
|
||||
if err := updater.DoUpgradeAtStartup(); err != nil {
|
||||
slog.Info("unable to perform upgrade at startup", "error", err)
|
||||
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
||||
|
||||
@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
|
||||
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
|
||||
@property(strong, nonatomic) NSStatusItem *statusItem;
|
||||
@property(assign, nonatomic) BOOL updateAvailable;
|
||||
@property(assign, nonatomic) BOOL systemShutdownInProgress;
|
||||
@end
|
||||
|
||||
@implementation AppDelegate
|
||||
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
}
|
||||
|
||||
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
|
||||
// Register for system shutdown/restart notification so we can allow termination
|
||||
[[[NSWorkspace sharedWorkspace] notificationCenter]
|
||||
addObserver:self
|
||||
selector:@selector(systemWillPowerOff:)
|
||||
name:NSWorkspaceWillPowerOffNotification
|
||||
object:nil];
|
||||
|
||||
// if we're in development mode, set the app icon
|
||||
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
|
||||
if (![bundlePath hasSuffix:@".app"]) {
|
||||
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
[NSApp activateIgnoringOtherApps:YES];
|
||||
}
|
||||
|
||||
- (void)systemWillPowerOff:(NSNotification *)notification {
|
||||
// Set flag so applicationShouldTerminate: knows to allow termination.
|
||||
// The system will call applicationShouldTerminate: after posting this notification.
|
||||
self.systemShutdownInProgress = YES;
|
||||
}
|
||||
|
||||
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
|
||||
// Allow termination if the system is shutting down or restarting
|
||||
if (self.systemShutdownInProgress) {
|
||||
return NSTerminateNow;
|
||||
}
|
||||
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
|
||||
[NSApp hide:nil];
|
||||
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
|
||||
return NSTerminateCancel;
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// currentSchemaVersion defines the current database schema version.
|
||||
// Increment this when making schema changes that require migrations.
|
||||
const currentSchemaVersion = 12
|
||||
const currentSchemaVersion = 13
|
||||
|
||||
// database wraps the SQLite connection.
|
||||
// SQLite handles its own locking for concurrent access:
|
||||
@@ -85,6 +85,7 @@ func (db *database) init() error {
|
||||
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
think_level TEXT NOT NULL DEFAULT '',
|
||||
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
||||
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
schema_version INTEGER NOT NULL DEFAULT %d
|
||||
);
|
||||
|
||||
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
|
||||
return fmt.Errorf("migrate v11 to v12: %w", err)
|
||||
}
|
||||
version = 12
|
||||
case 12:
|
||||
// add auto_update_enabled column to settings table
|
||||
if err := db.migrateV12ToV13(); err != nil {
|
||||
return fmt.Errorf("migrate v12 to v13: %w", err)
|
||||
}
|
||||
version = 13
|
||||
default:
|
||||
// If we have a version we don't recognize, just set it to current
|
||||
// This might happen during development
|
||||
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateV12ToV13 adds the auto_update_enabled column to the settings table
|
||||
func (db *database) migrateV12ToV13() error {
|
||||
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
|
||||
if err != nil && !duplicateColumnError(err) {
|
||||
return fmt.Errorf("add auto_update_enabled column: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update schema version: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||
func (db *database) cleanupOrphanedData() error {
|
||||
_, err := db.conn.Exec(`
|
||||
@@ -482,19 +504,11 @@ func (db *database) cleanupOrphanedData() error {
|
||||
}
|
||||
|
||||
func duplicateColumnError(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "duplicate column name")
|
||||
}
|
||||
return false
|
||||
return err != nil && strings.Contains(err.Error(), "duplicate column name")
|
||||
}
|
||||
|
||||
func columnNotExists(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "no such column")
|
||||
}
|
||||
return false
|
||||
return err != nil && strings.Contains(err.Error(), "no such column")
|
||||
}
|
||||
|
||||
func (db *database) getAllChats() ([]Chat, error) {
|
||||
@@ -1108,9 +1122,9 @@ func (db *database) getSettings() (Settings, error) {
|
||||
var s Settings
|
||||
|
||||
err := db.conn.QueryRow(`
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
|
||||
FROM settings
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
||||
if err != nil {
|
||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||
}
|
||||
@@ -1121,8 +1135,8 @@ func (db *database) getSettings() (Settings, error) {
|
||||
func (db *database) setSettings(s Settings) error {
|
||||
_, err := db.conn.Exec(`
|
||||
UPDATE settings
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set settings: %w", err)
|
||||
}
|
||||
|
||||
@@ -169,6 +169,9 @@ type Settings struct {
|
||||
|
||||
// SidebarOpen indicates if the chat sidebar is open
|
||||
SidebarOpen bool
|
||||
|
||||
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
||||
AutoUpdateEnabled bool
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
|
||||
@@ -413,6 +413,7 @@ export class Settings {
|
||||
ThinkLevel: string;
|
||||
SelectedModel: string;
|
||||
SidebarOpen: boolean;
|
||||
AutoUpdateEnabled: boolean;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
@@ -431,6 +432,7 @@ export class Settings {
|
||||
this.ThinkLevel = source["ThinkLevel"];
|
||||
this.SelectedModel = source["SelectedModel"];
|
||||
this.SidebarOpen = source["SidebarOpen"];
|
||||
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
||||
}
|
||||
}
|
||||
export class SettingsResponse {
|
||||
@@ -467,6 +469,46 @@ export class HealthResponse {
|
||||
this.healthy = source["healthy"];
|
||||
}
|
||||
}
|
||||
export class UpdateInfo {
|
||||
currentVersion: string;
|
||||
availableVersion: string;
|
||||
updateAvailable: boolean;
|
||||
updateDownloaded: boolean;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.currentVersion = source["currentVersion"];
|
||||
this.availableVersion = source["availableVersion"];
|
||||
this.updateAvailable = source["updateAvailable"];
|
||||
this.updateDownloaded = source["updateDownloaded"];
|
||||
}
|
||||
}
|
||||
export class UpdateCheckResponse {
|
||||
updateInfo: UpdateInfo;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
if (!a) {
|
||||
return a;
|
||||
}
|
||||
if (Array.isArray(a)) {
|
||||
return (a as any[]).map(elem => this.convertValues(elem, classs));
|
||||
} else if ("object" === typeof a) {
|
||||
if (asMap) {
|
||||
for (const key of Object.keys(a)) {
|
||||
a[key] = new classs(a[key]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
return new classs(a);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
}
|
||||
export class User {
|
||||
id: string;
|
||||
email: string;
|
||||
|
||||
@@ -414,3 +414,54 @@ export async function fetchHealth(): Promise<boolean> {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function getCurrentVersion(): Promise<string> {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/api/version`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
return data.version || "Unknown";
|
||||
}
|
||||
return "Unknown";
|
||||
} catch (error) {
|
||||
console.error("Error fetching version:", error);
|
||||
return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkForUpdate(): Promise<{
|
||||
currentVersion: string;
|
||||
availableVersion: string;
|
||||
updateAvailable: boolean;
|
||||
updateDownloaded: boolean;
|
||||
}> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/update/check`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to check for update");
|
||||
}
|
||||
const data = await response.json();
|
||||
return data.updateInfo;
|
||||
}
|
||||
|
||||
export async function installUpdate(): Promise<void> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/update/install`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(error || "Failed to install update");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,12 +14,13 @@ import {
|
||||
XMarkIcon,
|
||||
CogIcon,
|
||||
ArrowLeftIcon,
|
||||
ArrowDownTrayIcon,
|
||||
} from "@heroicons/react/20/solid";
|
||||
import { Settings as SettingsType } from "@/gotypes";
|
||||
import { useNavigate } from "@tanstack/react-router";
|
||||
import { useUser } from "@/hooks/useUser";
|
||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { getSettings, updateSettings } from "@/api";
|
||||
import { getSettings, updateSettings, checkForUpdate } from "@/api";
|
||||
|
||||
function AnimatedDots() {
|
||||
return (
|
||||
@@ -39,6 +40,12 @@ export default function Settings() {
|
||||
const queryClient = useQueryClient();
|
||||
const [showSaved, setShowSaved] = useState(false);
|
||||
const [restartMessage, setRestartMessage] = useState(false);
|
||||
const [updateInfo, setUpdateInfo] = useState<{
|
||||
currentVersion: string;
|
||||
availableVersion: string;
|
||||
updateAvailable: boolean;
|
||||
updateDownloaded: boolean;
|
||||
} | null>(null);
|
||||
const {
|
||||
user,
|
||||
isAuthenticated,
|
||||
@@ -76,6 +83,10 @@ export default function Settings() {
|
||||
|
||||
useEffect(() => {
|
||||
refetchUser();
|
||||
// Check for updates on mount
|
||||
checkForUpdate()
|
||||
.then(setUpdateInfo)
|
||||
.catch((err) => console.error("Error checking for update:", err));
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
useEffect(() => {
|
||||
@@ -344,6 +355,58 @@ export default function Settings() {
|
||||
{/* Local Configuration */}
|
||||
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
||||
<div className="space-y-4 p-4">
|
||||
{/* Auto Update */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="flex items-start space-x-3 flex-1">
|
||||
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||
<div className="flex-1">
|
||||
<Label>Auto-download updates</Label>
|
||||
<Description>
|
||||
{settings.AutoUpdateEnabled ? (
|
||||
<>
|
||||
Automatically downloads updates when available.
|
||||
<div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
|
||||
Current version: {updateInfo?.currentVersion || "Loading..."}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
Manually download updates.
|
||||
<div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
|
||||
<div className="space-y-2 text-sm">
|
||||
<div className="flex justify-between">
|
||||
<span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
|
||||
</div>
|
||||
{updateInfo?.availableVersion && (
|
||||
<div className="flex justify-between">
|
||||
<span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<a
|
||||
href="https://ollama.com/download"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
|
||||
>
|
||||
Download new version →
|
||||
</a>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</Description>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-shrink-0">
|
||||
<Switch
|
||||
checked={settings.AutoUpdateEnabled}
|
||||
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
{/* Expose Ollama */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
|
||||
@@ -100,6 +100,17 @@ type HealthResponse struct {
|
||||
Healthy bool `json:"healthy"`
|
||||
}
|
||||
|
||||
type UpdateInfo struct {
|
||||
CurrentVersion string `json:"currentVersion"`
|
||||
AvailableVersion string `json:"availableVersion"`
|
||||
UpdateAvailable bool `json:"updateAvailable"`
|
||||
UpdateDownloaded bool `json:"updateDownloaded"`
|
||||
}
|
||||
|
||||
type UpdateCheckResponse struct {
|
||||
UpdateInfo UpdateInfo `json:"updateInfo"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
|
||||
100
app/ui/ui.go
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/ollama/ollama/app/tools"
|
||||
"github.com/ollama/ollama/app/types/not"
|
||||
"github.com/ollama/ollama/app/ui/responses"
|
||||
"github.com/ollama/ollama/app/updater"
|
||||
"github.com/ollama/ollama/app/version"
|
||||
ollamaAuth "github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -106,6 +107,18 @@ type Server struct {
|
||||
|
||||
// Dev is true if the server is running in development mode
|
||||
Dev bool
|
||||
|
||||
// Updater for checking and downloading updates
|
||||
Updater UpdaterInterface
|
||||
UpdateAvailableFunc func()
|
||||
}
|
||||
|
||||
// UpdaterInterface defines the methods we need from the updater
|
||||
type UpdaterInterface interface {
|
||||
CheckForUpdate(ctx context.Context) (bool, string, error)
|
||||
InstallAndRestart() error
|
||||
CancelOngoingDownload()
|
||||
TriggerImmediateCheck()
|
||||
}
|
||||
|
||||
func (s *Server) log() *slog.Logger {
|
||||
@@ -284,6 +297,8 @@ func (s *Server) Handler() http.Handler {
|
||||
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
|
||||
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
|
||||
mux.Handle("POST /api/v1/settings", handle(s.settings))
|
||||
mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
|
||||
mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))
|
||||
|
||||
// Ollama proxy endpoints
|
||||
ollamaProxy := s.ollamaProxy()
|
||||
@@ -1448,6 +1463,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
||||
return fmt.Errorf("failed to save settings: %w", err)
|
||||
}
|
||||
|
||||
// Handle auto-update toggle changes
|
||||
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
|
||||
if !settings.AutoUpdateEnabled {
|
||||
// Auto-update disabled: cancel any ongoing download
|
||||
if s.Updater != nil {
|
||||
s.Updater.CancelOngoingDownload()
|
||||
}
|
||||
} else {
|
||||
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
|
||||
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
|
||||
s.UpdateAvailableFunc()
|
||||
} else if s.Updater != nil {
|
||||
// Trigger the background checker to run immediately
|
||||
s.Updater.TriggerImmediateCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if old.ContextLength != settings.ContextLength ||
|
||||
old.Models != settings.Models ||
|
||||
old.Expose != settings.Expose {
|
||||
@@ -1524,6 +1557,73 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
|
||||
return json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
|
||||
currentVersion := version.Version
|
||||
|
||||
if s.Updater == nil {
|
||||
return fmt.Errorf("updater not available")
|
||||
}
|
||||
|
||||
updateAvailable, updateVersion, err := s.Updater.CheckForUpdate(r.Context())
|
||||
if err != nil {
|
||||
s.log().Warn("failed to check for update", "error", err)
|
||||
// Don't return error, just log it and continue with no update available
|
||||
}
|
||||
|
||||
response := responses.UpdateCheckResponse{
|
||||
UpdateInfo: responses.UpdateInfo{
|
||||
CurrentVersion: currentVersion,
|
||||
AvailableVersion: updateVersion,
|
||||
UpdateAvailable: updateAvailable,
|
||||
UpdateDownloaded: updater.UpdateDownloaded,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return fmt.Errorf("method not allowed")
|
||||
}
|
||||
|
||||
if s.Updater == nil {
|
||||
s.log().Error("install failed: updater not available")
|
||||
return fmt.Errorf("updater not available")
|
||||
}
|
||||
|
||||
// Check if update is downloaded
|
||||
if !updater.UpdateDownloaded {
|
||||
s.log().Error("install failed: no update downloaded")
|
||||
return fmt.Errorf("no update downloaded")
|
||||
}
|
||||
|
||||
// Send response before restarting
|
||||
response := map[string]any{
|
||||
"success": true,
|
||||
"message": "Installing update and restarting...",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Give the response time to be sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Trigger the upgrade and restart
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := s.Updater.InstallAndRestart(); err != nil {
|
||||
s.log().Error("failed to install update", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func userAgent() string {
|
||||
buildinfo, _ := debug.ReadBuildInfo()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/app/store"
|
||||
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||
query := requestURL.Query()
|
||||
query.Add("os", runtime.GOOS)
|
||||
query.Add("arch", runtime.GOARCH)
|
||||
query.Add("version", version.Version)
|
||||
currentVersion := version.Version
|
||||
query.Add("version", currentVersion)
|
||||
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
|
||||
// The original macOS app used to use the device ID
|
||||
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||
}
|
||||
|
||||
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||
// Create a cancellable context for this download
|
||||
downloadCtx, cancel := context.WithCancel(ctx)
|
||||
u.cancelDownloadLock.Lock()
|
||||
u.cancelDownload = cancel
|
||||
u.cancelDownloadLock.Unlock()
|
||||
defer func() {
|
||||
u.cancelDownloadLock.Lock()
|
||||
u.cancelDownload = nil
|
||||
u.cancelDownloadLock.Unlock()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Do a head first to check etag info
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// In case of slow downloads, continue the update check in the background
|
||||
bgctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
bgctx, bgcancel := context.WithCancel(downloadCtx)
|
||||
defer bgcancel()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
||||
_, err = os.Stat(stageFilename)
|
||||
if err == nil {
|
||||
slog.Info("update already downloaded", "bundle", stageFilename)
|
||||
UpdateDownloaded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -244,34 +259,95 @@ func cleanupOldDownloads(stageDir string) {
|
||||
}
|
||||
|
||||
type Updater struct {
|
||||
Store *store.Store
|
||||
Store *store.Store
|
||||
cancelDownload context.CancelFunc
|
||||
cancelDownloadLock sync.Mutex
|
||||
checkNow chan struct{}
|
||||
}
|
||||
|
||||
// CancelOngoingDownload cancels any currently running download
|
||||
func (u *Updater) CancelOngoingDownload() {
|
||||
u.cancelDownloadLock.Lock()
|
||||
defer u.cancelDownloadLock.Unlock()
|
||||
if u.cancelDownload != nil {
|
||||
slog.Info("cancelling ongoing update download")
|
||||
u.cancelDownload()
|
||||
u.cancelDownload = nil
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerImmediateCheck signals the background checker to check for updates immediately
|
||||
func (u *Updater) TriggerImmediateCheck() {
|
||||
if u.checkNow != nil {
|
||||
u.checkNow <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||
u.checkNow = make(chan struct{}, 1)
|
||||
go func() {
|
||||
// Don't blast an update message immediately after startup
|
||||
time.Sleep(UpdateCheckInitialDelay)
|
||||
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
||||
ticker := time.NewTicker(UpdateCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
if available {
|
||||
err := u.DownloadNewRelease(ctx, resp)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
|
||||
} else {
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("stopping background update checker")
|
||||
return
|
||||
default:
|
||||
time.Sleep(UpdateCheckInterval)
|
||||
case <-u.checkNow:
|
||||
// Immediate check triggered
|
||||
case <-ticker.C:
|
||||
// Regular interval check
|
||||
}
|
||||
|
||||
// Always check for updates
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
if !available {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update is available - check if auto-update is enabled for downloading
|
||||
settings, err := u.Store.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to load settings", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !settings.AutoUpdateEnabled {
|
||||
// Auto-update disabled - don't download, just log
|
||||
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
|
||||
continue
|
||||
}
|
||||
|
||||
// Auto-update is enabled - download
|
||||
err = u.DownloadNewRelease(ctx, resp)
|
||||
if err != nil {
|
||||
slog.Error("failed to download new release", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Download successful - show tray notification (regardless of toggle state)
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn("failed to register update available with tray", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
return available, resp.UpdateVersion, nil
|
||||
}
|
||||
|
||||
func (u *Updater) InstallAndRestart() error {
|
||||
if !UpdateDownloaded {
|
||||
return fmt.Errorf("no update downloaded")
|
||||
}
|
||||
|
||||
slog.Info("installing update and restarting")
|
||||
return DoUpgrade(true)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -85,7 +86,17 @@ func TestBackgoundChecker(t *testing.T) {
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{}}
|
||||
defer updater.Store.Close() // Ensure database is closed
|
||||
defer updater.Store.Close()
|
||||
|
||||
settings, err := updater.Store.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = true
|
||||
if err := updater.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
@@ -99,3 +110,187 @@ func TestBackgoundChecker(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
var downloadAttempted atomic.Bool
|
||||
done := make(chan struct{})
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
UpdateCheckInitialDelay = 5 * time.Millisecond
|
||||
UpdateCheckInterval = 5 * time.Millisecond
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||
downloadAttempted.Store(true)
|
||||
buf := &bytes.Buffer{}
|
||||
zw := zip.NewWriter(buf)
|
||||
zw.Close()
|
||||
io.Copy(w, buf)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
// Ensure auto-update is disabled
|
||||
settings, err := updater.Store.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = false
|
||||
if err := updater.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := func(ver string) error {
|
||||
t.Fatal("callback should not be called when auto-update is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
|
||||
// Wait enough time for multiple check cycles
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
close(done)
|
||||
|
||||
if downloadAttempted.Load() {
|
||||
t.Fatal("download should not be attempted when auto-update is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelOngoingDownload(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
downloadStarted := make(chan struct{})
|
||||
downloadCancelled := make(chan struct{})
|
||||
|
||||
ctx := t.Context()
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("Content-Length", "1000000")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
// Signal that download has started
|
||||
close(downloadStarted)
|
||||
// Wait for cancellation or timeout
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
close(downloadCancelled)
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("download was not cancelled in time")
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
_, resp := updater.checkForUpdate(ctx)
|
||||
|
||||
// Start download in goroutine
|
||||
go func() {
|
||||
_ = updater.DownloadNewRelease(ctx, resp)
|
||||
}()
|
||||
|
||||
// Wait for download to start
|
||||
select {
|
||||
case <-downloadStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("download did not start in time")
|
||||
}
|
||||
|
||||
// Cancel the download
|
||||
updater.CancelOngoingDownload()
|
||||
|
||||
// Verify cancellation was received
|
||||
select {
|
||||
case <-downloadCancelled:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("download cancellation was not received by server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriggerImmediateCheck(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
checkCount := atomic.Int32{}
|
||||
checkDone := make(chan struct{}, 10)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
// Set a very long interval so only TriggerImmediateCheck causes checks
|
||||
UpdateCheckInitialDelay = 1 * time.Millisecond
|
||||
UpdateCheckInterval = 1 * time.Hour
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
checkCount.Add(1)
|
||||
select {
|
||||
case checkDone <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
// Return no update available
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
cb := func(ver string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
|
||||
// Wait for goroutine to start and pass initial delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// With 1 hour interval, no check should have happened yet
|
||||
initialCount := checkCount.Load()
|
||||
|
||||
// Trigger immediate check
|
||||
updater.TriggerImmediateCheck()
|
||||
|
||||
// Wait for the triggered check
|
||||
select {
|
||||
case <-checkDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("triggered check did not happen")
|
||||
}
|
||||
|
||||
finalCount := checkCount.Load()
|
||||
if finalCount <= initialCount {
|
||||
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
|
||||
// const ERROR_SUCCESS syscall.Errno = 0
|
||||
|
||||
// t.muMenus.RLock()
|
||||
// menu := uintptr(t.menus[parentId])
|
||||
// t.muMenus.RUnlock()
|
||||
// res, _, err := pRemoveMenu.Call(
|
||||
// menu,
|
||||
// uintptr(menuItemId),
|
||||
// MF_BYCOMMAND,
|
||||
// )
|
||||
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
||||
// return err
|
||||
// }
|
||||
// t.delFromVisibleItems(parentId, menuItemId)
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (t *winTray) showMenu() error {
|
||||
p := point{}
|
||||
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
|
||||
|
||||
@@ -51,7 +51,6 @@ const (
|
||||
IMAGE_ICON = 1 // Loads an icon
|
||||
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
|
||||
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
|
||||
MF_BYCOMMAND = 0x00000000
|
||||
MFS_DISABLED = 0x00000003
|
||||
MFT_SEPARATOR = 0x00000800
|
||||
MFT_STRING = 0x00000000
|
||||
|
||||
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const results = await client.webSearch({ query: "what is ollama?" });
|
||||
const results = await client.webSearch("what is ollama?");
|
||||
console.log(JSON.stringify(results, null, 2));
|
||||
```
|
||||
|
||||
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
|
||||
const fetchResult = await client.webFetch("https://ollama.com");
|
||||
console.log(JSON.stringify(fetchResult, null, 2));
|
||||
```
|
||||
|
||||
|
||||
@@ -111,7 +111,9 @@
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode"
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
|
||||
BIN
docs/images/marimo-add-model.png
Normal file
|
After Width: | Height: | Size: 174 KiB |
BIN
docs/images/marimo-chat.png
Normal file
|
After Width: | Height: | Size: 80 KiB |
BIN
docs/images/marimo-code-completion.png
Normal file
|
After Width: | Height: | Size: 230 KiB |
BIN
docs/images/marimo-models.png
Normal file
|
After Width: | Height: | Size: 178 KiB |
BIN
docs/images/marimo-settings.png
Normal file
|
After Width: | Height: | Size: 186 KiB |
BIN
docs/images/onyx-login.png
Normal file
|
After Width: | Height: | Size: 100 KiB |
BIN
docs/images/onyx-ollama-form.png
Normal file
|
After Width: | Height: | Size: 306 KiB |
BIN
docs/images/onyx-ollama-llm.png
Normal file
|
After Width: | Height: | Size: 300 KiB |
BIN
docs/images/onyx-query.png
Normal file
|
After Width: | Height: | Size: 211 KiB |
@@ -25,6 +25,7 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
@@ -38,7 +39,7 @@ claude --model qwen3-coder
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
73
docs/integrations/marimo.mdx
Normal file
@@ -0,0 +1,73 @@
|
||||
---
|
||||
title: marimo
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
|
||||
can also use `uv` to create a sandboxed environment for marimo by running:
|
||||
|
||||
```
|
||||
uvx marimo edit --sandbox notebook.py
|
||||
```
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. In marimo, go to the user settings and go to the AI tab. From here
|
||||
you can find and configure Ollama as an AI provider. For local use you
|
||||
would typically point the base url to `http://localhost:11434/v1`.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-settings.png"
|
||||
alt="Ollama settings in marimo"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-models.png"
|
||||
alt="Selecting an Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-add-model.png"
|
||||
alt="Adding a new Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Once configured, you can now use Ollama for AI chats in marimo.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-chat.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-code-completion.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Sign in to ollama cloud via `ollama signin`
|
||||
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
|
||||
3. You can now refer to this model in marimo!
|
||||
63
docs/integrations/onyx.mdx
Normal file
@@ -0,0 +1,63 @@
|
||||
---
|
||||
title: Onyx
|
||||
---
|
||||
|
||||
## Overview
|
||||
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
|
||||
- Creating custom Agents
|
||||
- Web search
|
||||
- Deep Research
|
||||
- RAG over uploaded documents and connected apps
|
||||
- Connectors to applications like Google Drive, Email, Slack, etc.
|
||||
- MCP and OpenAPI Actions support
|
||||
- Image generation
|
||||
- User/Groups management, RBAC, SSO, etc.
|
||||
|
||||
Onyx can be deployed for single users or large organizations.
|
||||
|
||||
## Install Onyx
|
||||
|
||||
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
|
||||
|
||||
<Info>
|
||||
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
|
||||
</Info>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. Login to your Onyx deployment (create an account first).
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-login.png"
|
||||
alt="Onyx Login Page"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
2. In the set-up process select `Ollama` as the LLM provider.
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-llm.png"
|
||||
alt="Onyx Set Up Form"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
3. Provide your **Ollama API URL** and select your models.
|
||||
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-form.png"
|
||||
alt="Selecting Ollama Models"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
|
||||
|
||||
## Send your first query
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-query.png"
|
||||
alt="Onyx Query Example"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Linux"
|
||||
title: Linux
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -40,8 +41,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -112,7 +113,11 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -1464,6 +1464,11 @@ type CompletionRequest struct {
|
||||
|
||||
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
|
||||
TopLogprobs int
|
||||
|
||||
// Image generation fields
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1512,6 +1517,11 @@ type CompletionResponse struct {
|
||||
|
||||
// Logprobs contains log probability information if requested
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Image generation fields
|
||||
Image []byte `json:"image,omitempty"` // Generated image
|
||||
Step int `json:"step,omitempty"` // Current generation step
|
||||
Total int `json:"total,omitempty"` // Total generation steps
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -441,6 +442,7 @@ type ResponsesWriter struct {
|
||||
stream bool
|
||||
responseID string
|
||||
itemID string
|
||||
request openai.ResponsesRequest
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||
@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// Non-streaming response
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
|
||||
completedAt := time.Now().Unix()
|
||||
response.CompletedAt = &completedAt
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
|
||||
w := &ResponsesWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
|
||||
model: req.Model,
|
||||
stream: streamRequested,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
request: req,
|
||||
}
|
||||
|
||||
// Set headers based on streaming mode
|
||||
|
||||
@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
|
||||
|
||||
// decodeImageURL decodes a base64 data URI into raw image bytes.
|
||||
func decodeImageURL(url string) (api.ImageData, error) {
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
|
||||
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||
|
||||
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -265,9 +266,9 @@ type ResponsesText struct {
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Strict bool `json:"strict,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
Description *string `json:"description"` // nullable but required
|
||||
Strict *bool `json:"strict"` // nullable but required
|
||||
Parameters map[string]any `json:"parameters"` // nullable but required
|
||||
}
|
||||
|
||||
type ResponsesRequest struct {
|
||||
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var description string
|
||||
if t.Description != nil {
|
||||
description = *t.Description
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: t.Type,
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
|
||||
|
||||
// Response types for the Responses API
|
||||
|
||||
// ResponsesTextField represents the text output configuration in the response.
|
||||
type ResponsesTextField struct {
|
||||
Format ResponsesTextFormat `json:"format"`
|
||||
}
|
||||
|
||||
// ResponsesReasoningOutput represents reasoning configuration in the response.
|
||||
type ResponsesReasoningOutput struct {
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Summary *string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError represents an error in the response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails represents details about why a response was incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
|
||||
// requires additional plumbing to find the effective values since the
|
||||
// defaults can come from the model or the request
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt *int64 `json:"completed_at"`
|
||||
Status string `json:"status"`
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
|
||||
Model string `json:"model"`
|
||||
PreviousResponseID *string `json:"previous_response_id"`
|
||||
Instructions *string `json:"instructions"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Error *ResponsesError `json:"error"`
|
||||
Tools []ResponsesTool `json:"tools"`
|
||||
ToolChoice any `json:"tool_choice"`
|
||||
Truncation string `json:"truncation"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
Text ResponsesTextField `json:"text"`
|
||||
TopP float64 `json:"top_p"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
TopLogprobs int `json:"top_logprobs"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
|
||||
Usage *ResponsesUsage `json:"usage"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
MaxToolCalls *int `json:"max_tool_calls"`
|
||||
Store bool `json:"store"`
|
||||
Background bool `json:"background"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
SafetyIdentifier *string `json:"safety_identifier"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key"`
|
||||
}
|
||||
|
||||
type ResponsesOutputItem struct {
|
||||
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Annotations []any `json:"annotations"`
|
||||
Logprobs []any `json:"logprobs"`
|
||||
}
|
||||
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
|
||||
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
|
||||
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
|
||||
func derefFloat64(p *float64, def float64) float64 {
|
||||
if p != nil {
|
||||
return *p
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response.
|
||||
// The request is used to echo back request parameters in the response.
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
|
||||
var output []ResponsesOutputItem
|
||||
|
||||
// Add reasoning item if thinking is present
|
||||
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
|
||||
output = append(output, ResponsesOutputItem{
|
||||
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
|
||||
Role: "assistant",
|
||||
Content: []ResponsesOutputContent{
|
||||
{
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Annotations: []any{},
|
||||
Logprobs: []any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var instructions *string
|
||||
if request.Instructions != "" {
|
||||
instructions = &request.Instructions
|
||||
}
|
||||
|
||||
// Build truncation with default
|
||||
truncation := "disabled"
|
||||
if request.Truncation != nil {
|
||||
truncation = *request.Truncation
|
||||
}
|
||||
|
||||
tools := request.Tools
|
||||
if tools == nil {
|
||||
tools = []ResponsesTool{}
|
||||
}
|
||||
|
||||
text := ResponsesTextField{
|
||||
Format: ResponsesTextFormat{Type: "text"},
|
||||
}
|
||||
if request.Text != nil && request.Text.Format != nil {
|
||||
text.Format = *request.Text.Format
|
||||
}
|
||||
|
||||
// Build reasoning output from request
|
||||
var reasoning *ResponsesReasoningOutput
|
||||
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
|
||||
reasoning = &ResponsesReasoningOutput{}
|
||||
if request.Reasoning.Effort != "" {
|
||||
reasoning.Effort = &request.Reasoning.Effort
|
||||
}
|
||||
if request.Reasoning.Summary != "" {
|
||||
reasoning.Summary = &request.Reasoning.Summary
|
||||
}
|
||||
}
|
||||
|
||||
return ResponsesResponse{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
Status: "completed",
|
||||
Model: model,
|
||||
Output: output,
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
CompletedAt: nil, // Set by middleware when writing final response
|
||||
Status: "completed",
|
||||
IncompleteDetails: nil, // Only populated if response incomplete
|
||||
Model: model,
|
||||
PreviousResponseID: nil, // Not supported
|
||||
Instructions: instructions,
|
||||
Output: output,
|
||||
Error: nil, // Only populated on failure
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // Default value
|
||||
Truncation: truncation,
|
||||
ParallelToolCalls: true, // Default value
|
||||
Text: text,
|
||||
TopP: derefFloat64(request.TopP, 1.0),
|
||||
PresencePenalty: 0, // Default value
|
||||
FrequencyPenalty: 0, // Default value
|
||||
TopLogprobs: 0, // Default value
|
||||
Temperature: derefFloat64(request.Temperature, 1.0),
|
||||
Reasoning: reasoning,
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: chatResponse.PromptEvalCount,
|
||||
OutputTokens: chatResponse.EvalCount,
|
||||
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
|
||||
// TODO(drifkin): wire through the actual values
|
||||
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
|
||||
// TODO(drifkin): wire through the actual values
|
||||
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
|
||||
},
|
||||
MaxOutputTokens: request.MaxOutputTokens,
|
||||
MaxToolCalls: nil, // Not supported
|
||||
Store: false, // We don't store responses
|
||||
Background: request.Background,
|
||||
ServiceTier: "default", // Default value
|
||||
Metadata: map[string]any{},
|
||||
SafetyIdentifier: nil, // Not supported
|
||||
PromptCacheKey: nil, // Not supported
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
|
||||
responseID string
|
||||
itemID string
|
||||
model string
|
||||
request ResponsesRequest
|
||||
|
||||
// State tracking (mutated across Process calls)
|
||||
firstWrite bool
|
||||
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
|
||||
}
|
||||
|
||||
// NewResponsesStreamConverter creates a new converter with the given configuration.
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
|
||||
return &ResponsesStreamConverter{
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
model: model,
|
||||
request: request,
|
||||
firstWrite: true,
|
||||
}
|
||||
}
|
||||
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
|
||||
return events
|
||||
}
|
||||
|
||||
// buildResponseObject creates a full response object with all required fields for streaming events.
|
||||
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
|
||||
var instructions any = nil
|
||||
if c.request.Instructions != "" {
|
||||
instructions = c.request.Instructions
|
||||
}
|
||||
|
||||
truncation := "disabled"
|
||||
if c.request.Truncation != nil {
|
||||
truncation = *c.request.Truncation
|
||||
}
|
||||
|
||||
var tools []any
|
||||
if c.request.Tools != nil {
|
||||
for _, t := range c.request.Tools {
|
||||
tools = append(tools, map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": t.Description,
|
||||
"strict": t.Strict,
|
||||
"parameters": t.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if tools == nil {
|
||||
tools = []any{}
|
||||
}
|
||||
|
||||
textFormat := map[string]any{"type": "text"}
|
||||
if c.request.Text != nil && c.request.Text.Format != nil {
|
||||
textFormat = map[string]any{
|
||||
"type": c.request.Text.Format.Type,
|
||||
}
|
||||
if c.request.Text.Format.Name != "" {
|
||||
textFormat["name"] = c.request.Text.Format.Name
|
||||
}
|
||||
if c.request.Text.Format.Schema != nil {
|
||||
textFormat["schema"] = c.request.Text.Format.Schema
|
||||
}
|
||||
if c.request.Text.Format.Strict != nil {
|
||||
textFormat["strict"] = *c.request.Text.Format.Strict
|
||||
}
|
||||
}
|
||||
|
||||
var reasoning any = nil
|
||||
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
|
||||
r := map[string]any{}
|
||||
if c.request.Reasoning.Effort != "" {
|
||||
r["effort"] = c.request.Reasoning.Effort
|
||||
} else {
|
||||
r["effort"] = nil
|
||||
}
|
||||
if c.request.Reasoning.Summary != "" {
|
||||
r["summary"] = c.request.Reasoning.Summary
|
||||
} else {
|
||||
r["summary"] = nil
|
||||
}
|
||||
reasoning = r
|
||||
}
|
||||
|
||||
// Build top_p and temperature with defaults
|
||||
topP := 1.0
|
||||
if c.request.TopP != nil {
|
||||
topP = *c.request.TopP
|
||||
}
|
||||
temperature := 1.0
|
||||
if c.request.Temperature != nil {
|
||||
temperature = *c.request.Temperature
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"completed_at": nil,
|
||||
"status": status,
|
||||
"incomplete_details": nil,
|
||||
"model": c.model,
|
||||
"previous_response_id": nil,
|
||||
"instructions": instructions,
|
||||
"output": output,
|
||||
"error": nil,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"truncation": truncation,
|
||||
"parallel_tool_calls": true,
|
||||
"text": map[string]any{"format": textFormat},
|
||||
"top_p": topP,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
"top_logprobs": 0,
|
||||
"temperature": temperature,
|
||||
"reasoning": reasoning,
|
||||
"usage": usage,
|
||||
"max_output_tokens": c.request.MaxOutputTokens,
|
||||
"max_tool_calls": nil,
|
||||
"store": false,
|
||||
"background": c.request.Background,
|
||||
"service_tier": "default",
|
||||
"metadata": map[string]any{},
|
||||
"safety_identifier": nil,
|
||||
"prompt_cache_key": nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.created", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.in_progress", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
|
||||
|
||||
// Emit delta
|
||||
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"delta": thinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"delta": thinking,
|
||||
}))
|
||||
|
||||
// TODO(drifkin): consider adding
|
||||
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
|
||||
|
||||
events := []ResponsesStreamEvent{
|
||||
c.newEvent("response.reasoning_summary_text.done", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"text": c.accumulatedThinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"text": c.accumulatedThinking,
|
||||
}),
|
||||
c.newEvent("response.output_item.done", map[string]any{
|
||||
"output_index": c.outputIndex,
|
||||
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": c.contentIndex,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"delta": content,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
return events
|
||||
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
})
|
||||
}
|
||||
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"text": c.accumulatedText,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
// response.content_part.done
|
||||
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
}},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// response.completed
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": c.buildFinalOutput(),
|
||||
"usage": map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
usage := map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
"output_tokens_details": map[string]any{
|
||||
"reasoning_tokens": 0,
|
||||
},
|
||||
}
|
||||
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
|
||||
response["completed_at"] = time.Now().Unix()
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": response,
|
||||
}))
|
||||
|
||||
return events
|
||||
|
||||
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk with content
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk with thinking
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
|
||||
Content: "The answer is 42",
|
||||
},
|
||||
Done: true,
|
||||
})
|
||||
}, ResponsesRequest{})
|
||||
|
||||
// Should have 2 output items: reasoning + message
|
||||
if len(response.Output) != 2 {
|
||||
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
// Verify that response.output_item.done includes content field for messages
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// First chunk
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.completed includes the output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
// Process some content
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.created includes an empty output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hi"},
|
||||
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
// Verify that events include incrementing sequence numbers
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hello"},
|
||||
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||
// Verify that function call items include status field
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -36,10 +37,11 @@ type Terminal struct {
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
pastedLines []string
|
||||
}
|
||||
|
||||
func New(prompt Prompt) (*Instance, error) {
|
||||
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharEsc:
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
i.pastedLines = nil
|
||||
i.Prompt.UseAlt = false
|
||||
return "", ErrInterrupt
|
||||
case CharPrev:
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharForward:
|
||||
buf.MoveRight()
|
||||
case CharBackspace, CharCtrlH:
|
||||
buf.Remove()
|
||||
if buf.IsEmpty() && len(i.pastedLines) > 0 {
|
||||
lastIdx := len(i.pastedLines) - 1
|
||||
prevLine := i.pastedLines[lastIdx]
|
||||
i.pastedLines = i.pastedLines[:lastIdx]
|
||||
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
|
||||
if len(i.pastedLines) == 0 {
|
||||
fmt.Print(i.Prompt.Prompt)
|
||||
i.Prompt.UseAlt = false
|
||||
} else {
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
}
|
||||
for _, r := range prevLine {
|
||||
buf.Add(r)
|
||||
}
|
||||
} else {
|
||||
buf.Remove()
|
||||
}
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for range 8 {
|
||||
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharCtrlZ:
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharEnter, CharCtrlJ:
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
case CharEnter:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
output = strings.Join(i.pastedLines, "\n") + "\n" + output
|
||||
i.pastedLines = nil
|
||||
}
|
||||
if output != "" {
|
||||
i.History.Add(output)
|
||||
}
|
||||
buf.MoveToEnd()
|
||||
fmt.Println()
|
||||
i.Prompt.UseAlt = false
|
||||
|
||||
return output, nil
|
||||
default:
|
||||
|
||||
@@ -179,7 +179,7 @@ _build_macapp() {
|
||||
fi
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
@@ -187,7 +187,7 @@ _build_macapp() {
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
$(xcrun -f stapler) staple dist/Ollama.app
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
|
||||
rm -f dist/Ollama.dmg
|
||||
|
||||
|
||||
@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
|
||||
redirectURL, err := challenge.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
|
||||
if redirectURL.Host != originalHost {
|
||||
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
|
||||
}
|
||||
|
||||
sha256sum := sha256.Sum256(nil)
|
||||
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||
|
||||
|
||||
113
server/auth_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
realm string
|
||||
originalHost string
|
||||
wantMismatch bool
|
||||
}{
|
||||
{"https://example.com/token", "example.com", false},
|
||||
{"https://example.com/token", "other.com", true},
|
||||
{"https://example.com/token", "localhost:8000", true},
|
||||
{"https://localhost:5000/token", "localhost:5000", false},
|
||||
{"https://localhost:5000/token", "localhost:6000", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.originalHost, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
|
||||
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
|
||||
|
||||
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
|
||||
if tt.wantMismatch && !isMismatch {
|
||||
t.Errorf("expected domain mismatch error, got: %v", err)
|
||||
}
|
||||
if !tt.wantMismatch && isMismatch {
|
||||
t.Errorf("unexpected domain mismatch error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistryChallenge(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantRealm, wantService, wantScope string
|
||||
}{
|
||||
{
|
||||
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
|
||||
"https://auth.example.com/token", "registry", "repo:foo:pull",
|
||||
},
|
||||
{
|
||||
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
|
||||
"https://r.ollama.ai/v2/token", "ollama", "-",
|
||||
},
|
||||
{"", "", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := parseRegistryChallenge(tt.input)
|
||||
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
|
||||
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
|
||||
tt.input, result.Realm, result.Service, result.Scope,
|
||||
tt.wantRealm, tt.wantService, tt.wantScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURL(t *testing.T) {
|
||||
challenge := registryChallenge{
|
||||
Realm: "https://auth.example.com/token",
|
||||
Service: "registry",
|
||||
Scope: "repo:foo:pull repo:bar:push",
|
||||
}
|
||||
|
||||
u, err := challenge.URL()
|
||||
if err != nil {
|
||||
t.Fatalf("URL() error: %v", err)
|
||||
}
|
||||
|
||||
if u.Host != "auth.example.com" {
|
||||
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
|
||||
}
|
||||
if u.Path != "/token" {
|
||||
t.Errorf("path = %q, want %q", u.Path, "/token")
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
if q.Get("service") != "registry" {
|
||||
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
|
||||
}
|
||||
if scopes := q["scope"]; len(scopes) != 2 {
|
||||
t.Errorf("scope count = %d, want 2", len(scopes))
|
||||
}
|
||||
if q.Get("ts") == "" {
|
||||
t.Error("missing ts")
|
||||
}
|
||||
if q.Get("nonce") == "" {
|
||||
t.Error("missing nonce")
|
||||
}
|
||||
|
||||
// Nonces should differ between calls
|
||||
u2, _ := challenge.URL()
|
||||
if q.Get("nonce") == u2.Query().Get("nonce") {
|
||||
t.Error("nonce should be unique per call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURLInvalid(t *testing.T) {
|
||||
challenge := registryChallenge{Realm: "://invalid"}
|
||||
if _, err := challenge.URL(); err == nil {
|
||||
t.Error("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
@@ -95,48 +95,11 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
|
||||
}
|
||||
|
||||
const (
|
||||
// numDownloadParts is the default number of concurrent download parts for standard downloads
|
||||
numDownloadParts = 16
|
||||
// numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
|
||||
// downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
|
||||
numHFDownloadParts = 4
|
||||
numDownloadParts = 16
|
||||
minDownloadPartSize int64 = 100 * format.MegaByte
|
||||
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
||||
)
|
||||
|
||||
// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
|
||||
// This includes:
|
||||
// - huggingface.co (main domain)
|
||||
// - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
|
||||
// - hf.co (shortlink domain)
|
||||
// - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
|
||||
func isHuggingFaceURL(u *url.URL) bool {
|
||||
if u == nil {
|
||||
return false
|
||||
}
|
||||
host := strings.ToLower(u.Hostname())
|
||||
return host == "huggingface.co" ||
|
||||
strings.HasSuffix(host, ".huggingface.co") ||
|
||||
host == "hf.co" ||
|
||||
strings.HasSuffix(host, ".hf.co")
|
||||
}
|
||||
|
||||
// getNumDownloadParts returns the number of concurrent download parts to use
|
||||
// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
|
||||
// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
|
||||
// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
|
||||
func getNumDownloadParts(u *url.URL) int {
|
||||
if isHuggingFaceURL(u) {
|
||||
if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return numHFDownloadParts
|
||||
}
|
||||
return numDownloadParts
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Name() string {
|
||||
return strings.Join([]string{
|
||||
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
||||
@@ -308,11 +271,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
}
|
||||
|
||||
g, inner := errgroup.WithContext(ctx)
|
||||
concurrency := getNumDownloadParts(directURL)
|
||||
if concurrency != numDownloadParts {
|
||||
slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
|
||||
}
|
||||
g.SetLimit(concurrency)
|
||||
g.SetLimit(numDownloadParts)
|
||||
for i := range b.Parts {
|
||||
part := b.Parts[i]
|
||||
if part.Completed.Load() == part.Size {
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsHuggingFaceURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil url",
|
||||
url: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "huggingface.co main domain",
|
||||
url: "https://huggingface.co/some/model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cdn-lfs.huggingface.co subdomain",
|
||||
url: "https://cdn-lfs.huggingface.co/repos/abc/123",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cdn-lfs3.hf.co CDN domain",
|
||||
url: "https://cdn-lfs3.hf.co/repos/abc/123",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "hf.co shortlink domain",
|
||||
url: "https://hf.co/model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase HuggingFace domain",
|
||||
url: "https://HUGGINGFACE.CO/model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "mixed case HF domain",
|
||||
url: "https://Cdn-Lfs.HF.Co/repos",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ollama registry",
|
||||
url: "https://registry.ollama.ai/v2/library/llama3",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "github.com",
|
||||
url: "https://github.com/ollama/ollama",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "fake huggingface domain",
|
||||
url: "https://nothuggingface.co/model",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "fake hf domain",
|
||||
url: "https://nothf.co/model",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "huggingface in path not host",
|
||||
url: "https://example.com/huggingface.co/model",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var u *url.URL
|
||||
if tc.url != "" {
|
||||
var err error
|
||||
u, err = url.Parse(tc.url)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse URL: %v", err)
|
||||
}
|
||||
}
|
||||
got := isHuggingFaceURL(u)
|
||||
assert.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNumDownloadParts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
envValue string
|
||||
expected int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "nil url returns default",
|
||||
url: "",
|
||||
envValue: "",
|
||||
expected: numDownloadParts,
|
||||
description: "nil URL should return standard concurrency",
|
||||
},
|
||||
{
|
||||
name: "ollama registry returns default",
|
||||
url: "https://registry.ollama.ai/v2/library/llama3",
|
||||
envValue: "",
|
||||
expected: numDownloadParts,
|
||||
description: "Ollama registry should use standard concurrency",
|
||||
},
|
||||
{
|
||||
name: "huggingface returns reduced default",
|
||||
url: "https://huggingface.co/model/repo",
|
||||
envValue: "",
|
||||
expected: numHFDownloadParts,
|
||||
description: "HuggingFace should use reduced concurrency",
|
||||
},
|
||||
{
|
||||
name: "hf.co CDN returns reduced default",
|
||||
url: "https://cdn-lfs3.hf.co/repos/abc/123",
|
||||
envValue: "",
|
||||
expected: numHFDownloadParts,
|
||||
description: "HuggingFace CDN should use reduced concurrency",
|
||||
},
|
||||
{
|
||||
name: "huggingface with env override",
|
||||
url: "https://huggingface.co/model/repo",
|
||||
envValue: "2",
|
||||
expected: 2,
|
||||
description: "OLLAMA_HF_CONCURRENCY should override default",
|
||||
},
|
||||
{
|
||||
name: "huggingface with higher env override",
|
||||
url: "https://huggingface.co/model/repo",
|
||||
envValue: "8",
|
||||
expected: 8,
|
||||
description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
|
||||
},
|
||||
{
|
||||
name: "huggingface with invalid env (non-numeric)",
|
||||
url: "https://huggingface.co/model/repo",
|
||||
envValue: "invalid",
|
||||
expected: numHFDownloadParts,
|
||||
description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
|
||||
},
|
||||
{
|
||||
name: "huggingface with invalid env (zero)",
|
||||
url: "https://huggingface.co/model/repo",
|
||||
envValue: "0",
|
||||
expected: numHFDownloadParts,
|
||||
description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
|
||||
},
|
||||
{
|
||||
name: "huggingface with invalid env (negative)",
|
||||
url: "https://huggingface.co/model/repo",
|
||||
envValue: "-1",
|
||||
expected: numHFDownloadParts,
|
||||
description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
|
||||
},
|
||||
{
|
||||
name: "non-huggingface ignores env",
|
||||
url: "https://registry.ollama.ai/v2/library/llama3",
|
||||
envValue: "2",
|
||||
expected: numDownloadParts,
|
||||
description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Set or clear the environment variable
|
||||
if tc.envValue != "" {
|
||||
t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
|
||||
}
|
||||
|
||||
var u *url.URL
|
||||
if tc.url != "" {
|
||||
var err error
|
||||
u, err = url.Parse(tc.url)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse URL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
got := getNumDownloadParts(u)
|
||||
assert.Equal(t, tc.expected, got, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}, base.Host)
|
||||
}
|
||||
|
||||
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
||||
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
})
|
||||
}, base.Host)
|
||||
}
|
||||
|
||||
return transfer.Upload(ctx, transfer.UploadOptions{
|
||||
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
|
||||
// Handle authentication error with one retry
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -51,7 +51,6 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
@@ -164,29 +163,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
// ScheduleImageGenRunner schedules an image generation model runner.
|
||||
// This implements the imagegenapi.RunnerScheduler interface.
|
||||
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
|
||||
m := &Model{
|
||||
Name: modelName,
|
||||
ShortName: modelName,
|
||||
ModelPath: modelName, // For image gen, ModelPath is just the model name
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.llama, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
@@ -214,12 +190,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a known image generation model
|
||||
if imagegen.ResolveModelName(req.Model) != "" {
|
||||
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
@@ -1587,13 +1557,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// Experimental OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", s.handleImageGeneration)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Experimental image generation support
|
||||
imagegenapi.RegisterRoutes(r, s)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
@@ -1911,6 +1880,62 @@ func toolCallId() string {
|
||||
return "call_" + strings.ToLower(string(b))
|
||||
}
|
||||
|
||||
func (s *Server) handleImageGeneration(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Size string `json:"size"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size (e.g., "1024x768") into width and height
|
||||
width, height := int32(1024), int32(1024)
|
||||
if req.Size != "" {
|
||||
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var image []byte
|
||||
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
}, func(resp llm.CompletionResponse) {
|
||||
if len(resp.Image) > 0 {
|
||||
image = resp.Image
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"created": time.Now().Unix(),
|
||||
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) ChatHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +16,6 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
|
||||
func (s *mockLlm) HasExited() bool { return false }
|
||||
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
|
||||
|
||||
// TestImageGenCapabilityDetection verifies that models with "image" capability
|
||||
// are correctly identified and routed differently from language models.
|
||||
func TestImageGenCapabilityDetection(t *testing.T) {
|
||||
// Model with image capability should be detected
|
||||
imageModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
|
||||
|
||||
// Model without image capability should not be detected
|
||||
langModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"completion"},
|
||||
},
|
||||
}
|
||||
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
|
||||
|
||||
// Empty capabilities should not match
|
||||
emptyModel := &Model{}
|
||||
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
|
||||
}
|
||||
|
||||
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
|
||||
// loaded in the scheduler can be evicted by a language model request.
|
||||
// loaded in the scheduler can be evicted when idle.
|
||||
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, "/fake/image/model", runner.modelPath)
|
||||
}
|
||||
|
||||
// TestImageGenSchedulerCoexistence verifies that image generation models
|
||||
// can coexist with language models in the scheduler and VRAM is tracked correctly.
|
||||
func TestImageGenSchedulerCoexistence(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
|
||||
// Load both an imagegen runner and a language model runner
|
||||
imageGenRunner := &runnerRef{
|
||||
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
|
||||
modelPath: "/fake/flux/model",
|
||||
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
langModelRunner := &runnerRef{
|
||||
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
|
||||
modelPath: "/fake/llama3/model",
|
||||
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["/fake/flux/model"] = imageGenRunner
|
||||
s.loaded["/fake/llama3/model"] = langModelRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify both are loaded
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
require.NotNil(t, s.loaded["/fake/flux/model"])
|
||||
require.NotNil(t, s.loaded["/fake/llama3/model"])
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify updateFreeSpace accounts for both
|
||||
gpus := []ml.DeviceInfo{
|
||||
{
|
||||
DeviceID: ml.DeviceID{Library: "Metal"},
|
||||
TotalMemory: 24 * format.GigaByte,
|
||||
FreeMemory: 24 * format.GigaByte,
|
||||
},
|
||||
}
|
||||
s.updateFreeSpace(gpus)
|
||||
|
||||
// Free memory should be reduced by both models
|
||||
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
|
||||
require.Equal(t, expectedFree, gpus[0].FreeMemory)
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
w.Rollback()
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
50
x/README.md
@@ -1,50 +0,0 @@
|
||||
# Experimental Features
|
||||
|
||||
## MLX Backend
|
||||
|
||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||
|
||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors.
|
||||
|
||||
### Building ollama-mlx
|
||||
|
||||
The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
|
||||
|
||||
#### macOS (Apple Silicon and Intel)
|
||||
|
||||
```bash
|
||||
# Build MLX backend libraries
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
|
||||
# Build ollama-mlx binary
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
#### Linux (CUDA)
|
||||
|
||||
On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
|
||||
|
||||
```bash
|
||||
# Build MLX backend libraries with CUDA support
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
cmake --build --preset 'MLX CUDA 13' --parallel
|
||||
cmake --install build --component MLX
|
||||
|
||||
# Build ollama-mlx binary
|
||||
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
|
||||
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
#### Using build scripts
|
||||
|
||||
The build scripts automatically create the `ollama-mlx` binary:
|
||||
|
||||
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
|
||||
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
|
||||
|
||||
## Image Generation
|
||||
|
||||
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
|
||||
67
x/cmd/run.go
@@ -25,14 +25,6 @@ import (
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// MultilineState tracks the state of multiline input
|
||||
type MultilineState int
|
||||
|
||||
const (
|
||||
MultilineNone MultilineState = iota
|
||||
MultilineSystem
|
||||
)
|
||||
|
||||
// Tool output capping constants
|
||||
const (
|
||||
// localModelTokenLimit is the token limit for local models (smaller context).
|
||||
@@ -656,7 +648,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -707,7 +699,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var sb strings.Builder
|
||||
var format string
|
||||
var system string
|
||||
var multiline MultilineState = MultilineNone
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -721,37 +712,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
multiline = MultilineNone
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case multiline != MultilineNone:
|
||||
// check if there's a multiline terminating string
|
||||
before, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
continue
|
||||
}
|
||||
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
multiline = MultilineNone
|
||||
scanner.Prompt.UseAlt = false
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
@@ -860,41 +826,18 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
options[args[2]] = fp[args[2]]
|
||||
case "system":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
|
||||
fmt.Println("Usage: /set system <message>")
|
||||
continue
|
||||
}
|
||||
|
||||
multiline = MultilineSystem
|
||||
|
||||
line := strings.Join(args[2:], " ")
|
||||
line, ok := strings.CutPrefix(line, `"""`)
|
||||
if !ok {
|
||||
multiline = MultilineNone
|
||||
} else {
|
||||
// only cut suffix if the line is multiline
|
||||
line, ok = strings.CutSuffix(line, `"""`)
|
||||
if ok {
|
||||
multiline = MultilineNone
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
if multiline != MultilineNone {
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||
// Check if the slice is not empty and the last message is from 'system'
|
||||
system = strings.Join(args[2:], " ")
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
// Replace the last message
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
@@ -1081,7 +1024,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
if sb.Len() > 0 {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
|
||||
@@ -1,231 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// RunnerScheduler is the interface for scheduling a model runner.
|
||||
// This is implemented by server.Server to avoid circular imports.
|
||||
type RunnerScheduler interface {
|
||||
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the image generation API routes.
|
||||
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
|
||||
r.POST("/v1/images/generations", func(c *gin.Context) {
|
||||
ImageGenerationHandler(c, scheduler)
|
||||
})
|
||||
}
|
||||
|
||||
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
|
||||
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
var req ImageGenerationRequest
|
||||
if err := c.BindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if req.N == 0 {
|
||||
req.N = 1
|
||||
}
|
||||
if req.Size == "" {
|
||||
req.Size = "1024x1024"
|
||||
}
|
||||
if req.ResponseFormat == "" {
|
||||
req.ResponseFormat = "b64_json"
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
if imagegen.ResolveModelName(req.Model) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size
|
||||
width, height := parseSize(req.Size)
|
||||
|
||||
// Build options - we repurpose NumCtx/NumGPU for width/height
|
||||
opts := api.Options{}
|
||||
opts.NumCtx = int(width)
|
||||
opts.NumGPU = int(height)
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
} else {
|
||||
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var imageBase64 string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imageBase64 = extractBase64(resp.Content)
|
||||
} else {
|
||||
progress := parseProgress(resp.Content)
|
||||
if progress.Total > 0 {
|
||||
c.SSEvent("progress", progress)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.SSEvent("done", buildResponse(imageBase64, format))
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
var imageBase64 string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imageBase64 = extractBase64(resp.Content)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractBase64(content string) string {
|
||||
if strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
return content[13:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseProgress(content string) ImageProgressEvent {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
return ImageProgressEvent{Step: step, Total: total}
|
||||
}
|
||||
|
||||
func buildResponse(imageBase64, format string) ImageGenerationResponse {
|
||||
resp := ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: make([]ImageData, 1),
|
||||
}
|
||||
|
||||
if imageBase64 == "" {
|
||||
return resp
|
||||
}
|
||||
|
||||
if format == "url" {
|
||||
// URL format not supported when using base64 transfer
|
||||
resp.Data[0].B64JSON = imageBase64
|
||||
} else {
|
||||
resp.Data[0].B64JSON = imageBase64
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't block - channel is already being consumed
|
||||
_ = err
|
||||
}
|
||||
}()
|
||||
|
||||
streamFn(c, ch)
|
||||
}
|
||||
|
||||
// GenerateResponse matches api.GenerateResponse structure for streaming.
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Package api provides OpenAI-compatible image generation API types.
|
||||
package api
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageData contains the generated image data.
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// ImageProgressEvent is sent during streaming to indicate generation progress.
|
||||
type ImageProgressEvent struct {
|
||||
Step int `json:"step"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
@@ -7,7 +7,6 @@ package imagegen
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -39,77 +38,17 @@ func DefaultOptions() ImageGenOptions {
|
||||
return ImageGenOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Steps: 0, // 0 means model default
|
||||
Seed: 0, // 0 means random
|
||||
}
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
ParameterCount int64
|
||||
Quantization string
|
||||
}
|
||||
|
||||
// GetModelInfo returns metadata about an image generation model.
|
||||
func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
info := &ModelInfo{}
|
||||
|
||||
// Read model_index.json for architecture, parameter count, and quantization
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ParameterCount int64 `json:"parameter_count"`
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
info.Architecture = index.Architecture
|
||||
info.ParameterCount = index.ParameterCount
|
||||
info.Quantization = index.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: detect quantization from tensor names if not in config
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
break
|
||||
}
|
||||
}
|
||||
if info.Quantization == "" {
|
||||
info.Quantization = "BF16"
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: estimate parameter count if not in config
|
||||
if info.ParameterCount == 0 {
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
}
|
||||
// Assume BF16 (2 bytes/param) as rough estimate
|
||||
info.ParameterCount = totalSize / 2
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// RegisterFlags adds image generation flags to the given command.
|
||||
// Flags are hidden since they only apply to image generation models.
|
||||
func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("width", 1024, "Image width")
|
||||
cmd.Flags().Int("height", 1024, "Image height")
|
||||
cmd.Flags().Int("steps", 9, "Denoising steps")
|
||||
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
cmd.Flags().MarkHidden("width")
|
||||
@@ -152,23 +91,18 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
|
||||
}
|
||||
|
||||
// generateImageWithOptions generates an image with the given options.
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
|
||||
// Note: opts are currently unused as the native API doesn't support size parameters.
|
||||
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build request with image gen options encoded in Options fields
|
||||
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
@@ -48,7 +49,7 @@ func main() {
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 9, "Denoising steps")
|
||||
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
|
||||
@@ -175,3 +175,63 @@ func (m *ModelManifest) HasTensorLayers() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
ParameterCount int64
|
||||
Quantization string
|
||||
}
|
||||
|
||||
// GetModelInfo returns metadata about an image generation model.
|
||||
func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
info := &ModelInfo{}
|
||||
|
||||
// Read model_index.json for architecture, parameter count, and quantization
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ParameterCount int64 `json:"parameter_count"`
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
info.Architecture = index.Architecture
|
||||
info.ParameterCount = index.ParameterCount
|
||||
info.Quantization = index.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: detect quantization from tensor names if not in config
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
break
|
||||
}
|
||||
}
|
||||
if info.Quantization == "" {
|
||||
info.Quantization = "BF16"
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: estimate parameter count if not in config
|
||||
if info.ParameterCount == 0 {
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
}
|
||||
// Assume BF16 (2 bytes/param) as rough estimate
|
||||
info.ParameterCount = totalSize / 2
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
@@ -95,8 +95,3 @@ func EstimateVRAM(modelName string) uint64 {
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// HasTensorLayers checks if the given model has tensor layers.
|
||||
func HasTensorLayers(modelName string) bool {
|
||||
return ResolveModelName(modelName) != ""
|
||||
}
|
||||
|
||||
@@ -94,13 +94,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasTensorLayers(t *testing.T) {
|
||||
// Non-existent model should return false
|
||||
if HasTensorLayers("nonexistent-model") {
|
||||
t.Error("HasTensorLayers() should return false for non-existent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -172,7 +173,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 30
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
|
||||
@@ -194,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 9 // Turbo default
|
||||
cfg.Steps = 9 // Z-Image turbo default
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
|
||||
@@ -136,16 +136,8 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Apply defaults
|
||||
if req.Width <= 0 {
|
||||
req.Width = 1024
|
||||
}
|
||||
if req.Height <= 0 {
|
||||
req.Height = 1024
|
||||
}
|
||||
if req.Steps <= 0 {
|
||||
req.Steps = 9
|
||||
}
|
||||
// Model applies its own defaults for width/height/steps
|
||||
// Only seed needs to be set here if not provided
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -25,6 +26,11 @@ import (
|
||||
)
|
||||
|
||||
// Server wraps an image generation subprocess to implement llm.LlamaServer.
|
||||
//
|
||||
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
|
||||
// like any other model. The plan is to eventually bring this into the llm/ package
|
||||
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
|
||||
// separate allows for independent iteration on image generation support.
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
@@ -37,22 +43,6 @@ type Server struct {
|
||||
lastErrLock sync.Mutex
|
||||
}
|
||||
|
||||
// completionRequest is sent to the subprocess
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// completionResponse is received from the subprocess
|
||||
type completionResponse struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// NewServer spawns a new image generation subprocess and waits until it's ready.
|
||||
func NewServer(modelName string) (*Server, error) {
|
||||
// Validate platform support before attempting to start
|
||||
@@ -139,7 +129,6 @@ func NewServer(modelName string) (*Server, error) {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
slog.Warn("image-runner", "msg", line)
|
||||
// Capture last error line for better error reporting
|
||||
s.lastErrLock.Lock()
|
||||
s.lastErr = line
|
||||
s.lastErrLock.Unlock()
|
||||
@@ -171,7 +160,6 @@ func (s *Server) ModelPath() string {
|
||||
return s.modelName
|
||||
}
|
||||
|
||||
// Load is called by the scheduler after the server is created.
|
||||
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -204,20 +192,16 @@ func (s *Server) waitUntilRunning() error {
|
||||
for {
|
||||
select {
|
||||
case err := <-s.done:
|
||||
// Include last stderr line for better error context
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
|
||||
// Include recent stderr lines for better error context
|
||||
errMsg := s.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
|
||||
}
|
||||
return fmt.Errorf("image runner exited unexpectedly: %w", err)
|
||||
case <-timeout:
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
|
||||
errMsg := s.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
|
||||
}
|
||||
return errors.New("timeout waiting for image runner to start")
|
||||
case <-ticker.C:
|
||||
@@ -229,44 +213,39 @@ func (s *Server) waitUntilRunning() error {
|
||||
}
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error {
|
||||
return nil
|
||||
// getLastErr returns the last stderr line.
|
||||
func (s *Server) getLastErr() string {
|
||||
s.lastErrLock.Lock()
|
||||
defer s.lastErrLock.Unlock()
|
||||
return s.lastErr
|
||||
}
|
||||
|
||||
// Completion generates an image from the prompt via the subprocess.
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
|
||||
|
||||
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
// Build request
|
||||
creq := completionRequest{
|
||||
seed := req.Seed
|
||||
if seed == 0 {
|
||||
seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Build request for subprocess
|
||||
creq := struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}{
|
||||
Prompt: req.Prompt,
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Seed: time.Now().UnixNano(),
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Seed: seed,
|
||||
}
|
||||
|
||||
if req.Options != nil {
|
||||
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
|
||||
creq.Width = int32(req.Options.NumCtx)
|
||||
}
|
||||
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
|
||||
creq.Height = int32(req.Options.NumGPU)
|
||||
}
|
||||
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
|
||||
creq.Steps = req.Options.NumPredict
|
||||
}
|
||||
if req.Options.Seed > 0 {
|
||||
creq.Seed = int64(req.Options.Seed)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode request body
|
||||
body, err := json.Marshal(creq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send request to subprocess
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
@@ -281,30 +260,40 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
|
||||
return fmt.Errorf("request failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Stream responses - use large buffer for base64 image data
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
|
||||
for scanner.Scan() {
|
||||
var cresp completionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
|
||||
// Parse subprocess response (has singular "image" field)
|
||||
var raw struct {
|
||||
Image string `json:"image,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content := cresp.Content
|
||||
// If this is the final response with an image, encode it in the content
|
||||
if cresp.Done && cresp.Image != "" {
|
||||
content = "IMAGE_BASE64:" + cresp.Image
|
||||
// Convert to llm.CompletionResponse
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
Total: raw.Total,
|
||||
}
|
||||
if raw.Image != "" {
|
||||
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
|
||||
cresp.Image = data
|
||||
}
|
||||
}
|
||||
|
||||
fn(llm.CompletionResponse{
|
||||
Content: content,
|
||||
Done: cresp.Done,
|
||||
})
|
||||
fn(cresp)
|
||||
if cresp.Done {
|
||||
break
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,22 +335,18 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// Embedding is not supported for image generation models.
|
||||
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return nil, 0, errors.New("embedding not supported for image generation models")
|
||||
return nil, 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
// Tokenize is not supported for image generation models.
|
||||
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return nil, errors.New("tokenize not supported for image generation models")
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
// Detokenize is not supported for image generation models.
|
||||
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", errors.New("detokenize not supported for image generation models")
|
||||
return "", errors.New("not supported")
|
||||
}
|
||||
|
||||
// Pid returns the subprocess PID.
|
||||
func (s *Server) Pid() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -371,17 +356,9 @@ func (s *Server) Pid() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// GetPort returns the subprocess port.
|
||||
func (s *Server) GetPort() int {
|
||||
return s.port
|
||||
}
|
||||
func (s *Server) GetPort() int { return s.port }
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
|
||||
|
||||
// GetDeviceInfos returns nil since we don't track GPU info.
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasExited returns true if the subprocess has exited.
|
||||
func (s *Server) HasExited() bool {
|
||||
select {
|
||||
case <-s.done:
|
||||
|
||||