Compare commits
2 commits: hoyyeva/up...parth/decr

| Author | SHA1 | Date |
|---|---|---|
|  | 6b2abfb433 |  |
|  | 805ed4644c |  |
@@ -190,7 +190,7 @@ if(MLX_ENGINE)
     install(TARGETS mlx mlxc
         RUNTIME_DEPENDENCIES
             DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
+            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
             PRE_EXCLUDE_REGEXES ".*"
         RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
         LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
README.md (43 changed lines)

@@ -48,7 +48,7 @@ ollama run gemma3

 ## Model library

-Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
+Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')

 Here are some example models that can be downloaded:

@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
 | Code Llama | 7B | 3.8GB | `ollama run codellama` |
 | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
 | Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |

 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

@@ -260,38 +260,6 @@ Finally, in a separate shell, run a model:
 ./ollama run llama3.2
 ```

-## Building with MLX (experimental)
-
-First build the MLX libraries:
-
-```shell
-cmake --preset MLX
-cmake --build --preset MLX --parallel
-cmake --install build --component MLX
-```
-
-Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
-
-```shell
-go build -tags mlx -o ollama-mlx .
-```
-
-Finally, start the server:
-
-```
-./ollama serve
-```
-
-### Building MLX with CUDA
-
-When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
-
-```shell
-cmake --preset 'MLX CUDA 13'
-cmake --build --preset 'MLX CUDA 13' --parallel
-cmake --install build --component MLX
-```
-
 ## REST API

 Ollama has a REST API for running and managing models.

@@ -322,7 +290,6 @@ See the [API documentation](./docs/api.md) for all endpoints.

 ### Web & Desktop

-- [Onyx](https://github.com/onyx-dot-app/onyx)
 - [Open WebUI](https://github.com/open-webui/open-webui)
 - [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)

@@ -454,7 +421,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
 - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
 - [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
 - [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
 - [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
 - [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
 - [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)

@@ -526,7 +493,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Database

 - [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
   - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
 - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
 - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
 - [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)

@@ -669,7 +636,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.

 ### Observability

 - [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.

@@ -678,5 +644,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.

 ### Security

 - [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
@@ -253,8 +253,6 @@ func main() {
     		done <- osrv.Run(octx)
     }()

-    upd := &updater.Updater{Store: st}
-
     uiServer := ui.Server{
         Token: token,
         Restart: func() {

@@ -269,10 +267,6 @@ func main() {
         ToolRegistry: toolRegistry,
         Dev:          devMode,
         Logger:       slog.Default(),
-        Updater:      upd,
-        UpdateAvailableFunc: func() {
-            UpdateAvailable("")
-        },
     }

     srv := &http.Server{

@@ -290,13 +284,8 @@ func main() {
         slog.Debug("background desktop server done")
     }()

-    upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
-
-    // Check for pending updates on startup (show tray notification if update is ready)
-    if updater.IsUpdatePending() {
-        slog.Debug("update pending on startup, showing tray notification")
-        UpdateAvailable("")
-    }
+    updater := &updater.Updater{Store: st}
+    updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)

     hasCompletedFirstRun, err := st.HasCompletedFirstRun()
     if err != nil {

@@ -359,18 +348,6 @@ func startHiddenTasks() {
         // CLI triggered app startup use-case
         slog.Info("deferring pending update for fast startup")
     } else {
-        // Check if auto-update is enabled before automatically upgrading
-        st := &store.Store{}
-        settings, err := st.Settings()
-        if err != nil {
-            slog.Warn("failed to load settings for upgrade check", "error", err)
-        } else if !settings.AutoUpdateEnabled {
-            slog.Info("auto-update disabled, skipping automatic upgrade at startup")
-            // Still show tray notification so user knows update is ready
-            UpdateAvailable("")
-            return
-        }
-
         if err := updater.DoUpgradeAtStartup(); err != nil {
             slog.Info("unable to perform upgrade at startup", "error", err)
             // Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
@@ -14,7 +14,6 @@ extern NSString *SystemWidePath;
 @interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
 @property(strong, nonatomic) NSStatusItem *statusItem;
 @property(assign, nonatomic) BOOL updateAvailable;
-@property(assign, nonatomic) BOOL systemShutdownInProgress;
 @end

 @implementation AppDelegate

@@ -41,13 +40,6 @@ bool firstTimeRun,startHidden; // Set in run before initialization
 }

 - (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
-  // Register for system shutdown/restart notification so we can allow termination
-  [[[NSWorkspace sharedWorkspace] notificationCenter]
-      addObserver:self
-         selector:@selector(systemWillPowerOff:)
-             name:NSWorkspaceWillPowerOffNotification
-           object:nil];
-
   // if we're in development mode, set the app icon
   NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
   if (![bundlePath hasSuffix:@".app"]) {

@@ -286,18 +278,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
   [NSApp activateIgnoringOtherApps:YES];
 }

-- (void)systemWillPowerOff:(NSNotification *)notification {
-  // Set flag so applicationShouldTerminate: knows to allow termination.
-  // The system will call applicationShouldTerminate: after posting this notification.
-  self.systemShutdownInProgress = YES;
-}
-
 - (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
-  // Allow termination if the system is shutting down or restarting
-  if (self.systemShutdownInProgress) {
-    return NSTerminateNow;
-  }
-  // Otherwise just hide the app (for Cmd+Q, close button, etc.)
   [NSApp hide:nil];
   [NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
   return NSTerminateCancel;
@@ -9,12 +9,12 @@ import (
     "strings"
     "time"

-    _ "github.com/mattn/go-sqlite3"
+    sqlite3 "github.com/mattn/go-sqlite3"
 )

 // currentSchemaVersion defines the current database schema version.
 // Increment this when making schema changes that require migrations.
-const currentSchemaVersion = 13
+const currentSchemaVersion = 12

 // database wraps the SQLite connection.
 // SQLite handles its own locking for concurrent access:

@@ -85,7 +85,6 @@ func (db *database) init() error {
             think_enabled BOOLEAN NOT NULL DEFAULT 0,
             think_level TEXT NOT NULL DEFAULT '',
             remote TEXT NOT NULL DEFAULT '', -- deprecated
-            auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
             schema_version INTEGER NOT NULL DEFAULT %d
         );

@@ -245,12 +244,6 @@ func (db *database) migrate() error {
             return fmt.Errorf("migrate v11 to v12: %w", err)
         }
         version = 12
-    case 12:
-        // add auto_update_enabled column to settings table
-        if err := db.migrateV12ToV13(); err != nil {
-            return fmt.Errorf("migrate v12 to v13: %w", err)
-        }
-        version = 13
     default:
         // If we have a version we don't recognize, just set it to current
         // This might happen during development

@@ -459,21 +452,6 @@ func (db *database) migrateV11ToV12() error {
     return nil
 }

-// migrateV12ToV13 adds the auto_update_enabled column to the settings table
-func (db *database) migrateV12ToV13() error {
-    _, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
-    if err != nil && !duplicateColumnError(err) {
-        return fmt.Errorf("add auto_update_enabled column: %w", err)
-    }
-
-    _, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
-    if err != nil {
-        return fmt.Errorf("update schema version: %w", err)
-    }
-
-    return nil
-}
-
 // cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
 func (db *database) cleanupOrphanedData() error {
     _, err := db.conn.Exec(`
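The two store hunks above unwind a stepwise schema migration: the `case 12:` rung is removed from the version ladder along with the `migrateV12ToV13` helper it called. For context, a minimal runnable sketch of that ladder pattern, assuming nothing beyond `database/sql` and the mattn/go-sqlite3 driver (the table and column names simply mirror the removed code):

```go
package main

import (
	"database/sql"
	"fmt"
	"strings"

	_ "github.com/mattn/go-sqlite3"
)

const currentSchemaVersion = 13

// migrate walks the version ladder one rung at a time; each rung is an
// idempotent, additive change followed by a schema_version bump.
func migrate(conn *sql.DB, version int) error {
	for version < currentSchemaVersion {
		switch version {
		case 12:
			// ALTER TABLE is safe to retry: a "duplicate column name"
			// error just means this rung already ran.
			_, err := conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
			if err != nil && !strings.Contains(err.Error(), "duplicate column name") {
				return fmt.Errorf("migrate v12 to v13: %w", err)
			}
			if _, err := conn.Exec(`UPDATE settings SET schema_version = 13`); err != nil {
				return fmt.Errorf("update schema version: %w", err)
			}
			version = 13
		default:
			// Unknown version: jump to current rather than fail.
			version = currentSchemaVersion
		}
	}
	return nil
}

func main() {
	conn, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	if _, err := conn.Exec(`CREATE TABLE settings (schema_version INTEGER NOT NULL DEFAULT 12)`); err != nil {
		panic(err)
	}
	if err := migrate(conn, 12); err != nil {
		panic(err)
	}
	fmt.Println("schema at version", currentSchemaVersion)
}
```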
@@ -504,11 +482,19 @@ func (db *database) cleanupOrphanedData() error {
 }

 func duplicateColumnError(err error) bool {
-    return err != nil && strings.Contains(err.Error(), "duplicate column name")
+    if sqlite3Err, ok := err.(sqlite3.Error); ok {
+        return sqlite3Err.Code == sqlite3.ErrError &&
+            strings.Contains(sqlite3Err.Error(), "duplicate column name")
+    }
+    return false
 }

 func columnNotExists(err error) bool {
-    return err != nil && strings.Contains(err.Error(), "no such column")
+    if sqlite3Err, ok := err.(sqlite3.Error); ok {
+        return sqlite3Err.Code == sqlite3.ErrError &&
+            strings.Contains(sqlite3Err.Error(), "no such column")
+    }
+    return false
 }

 func (db *database) getAllChats() ([]Chat, error) {
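This hunk tightens error classification: instead of substring-matching any error's string form, the helpers now assert the driver's typed `sqlite3.Error` and check its code before looking at the message. A small runnable sketch of the same idea, assuming the mattn/go-sqlite3 driver (the throwaway table is illustrative):

```go
package main

import (
	"database/sql"
	"fmt"
	"strings"

	sqlite3 "github.com/mattn/go-sqlite3"
)

// duplicateColumnError matches only errors the SQLite driver itself produced,
// so an unrelated error that happens to contain the phrase no longer passes.
func duplicateColumnError(err error) bool {
	if sqlite3Err, ok := err.(sqlite3.Error); ok {
		return sqlite3Err.Code == sqlite3.ErrError &&
			strings.Contains(sqlite3Err.Error(), "duplicate column name")
	}
	return false
}

func main() {
	conn, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	if _, err := conn.Exec(`CREATE TABLE t (a INTEGER)`); err != nil {
		panic(err)
	}
	_, err = conn.Exec(`ALTER TABLE t ADD COLUMN a INTEGER`) // duplicate on purpose
	fmt.Println(duplicateColumnError(err))                   // true
}
```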
@@ -1122,9 +1108,9 @@ func (db *database) getSettings() (Settings, error) {
     var s Settings

     err := db.conn.QueryRow(`
-        SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
+        SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
         FROM settings
-    `).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
+    `).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
     if err != nil {
         return Settings{}, fmt.Errorf("get settings: %w", err)
     }

@@ -1135,8 +1121,8 @@ func (db *database) getSettings() (Settings, error) {
 func (db *database) setSettings(s Settings) error {
     _, err := db.conn.Exec(`
         UPDATE settings
-        SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
-    `, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
+        SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
+    `, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
     if err != nil {
         return fmt.Errorf("set settings: %w", err)
     }
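Both settings hunks shrink a SQL column list and its Go argument list in the same motion. With `database/sql`, SELECT columns and `Scan` destinations (like placeholders and `Exec` arguments) are matched purely by position, so the two lists must always change together or the query fails at runtime with a destination-count mismatch. A reduced sketch of the shape, assuming a two-column `settings` table:

```go
package main

import (
	"database/sql"
	"fmt"

	_ "github.com/mattn/go-sqlite3"
)

type settings struct {
	Expose       bool
	ThinkEnabled bool
}

// getSettings: the SELECT column order and the Scan argument order are
// positional pairs; a column dropped from one list must leave the other too.
func getSettings(conn *sql.DB) (settings, error) {
	var s settings
	err := conn.QueryRow(`
		SELECT expose, think_enabled
		FROM settings
	`).Scan(&s.Expose, &s.ThinkEnabled)
	if err != nil {
		return settings{}, fmt.Errorf("get settings: %w", err)
	}
	return s, nil
}

func main() {
	conn, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	conn.Exec(`CREATE TABLE settings (expose BOOLEAN, think_enabled BOOLEAN)`)
	conn.Exec(`INSERT INTO settings VALUES (0, 1)`)
	fmt.Println(getSettings(conn)) // {false true} <nil>
}
```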
@@ -169,9 +169,6 @@ type Settings struct {

     // SidebarOpen indicates if the chat sidebar is open
     SidebarOpen bool
-
-    // AutoUpdateEnabled indicates if automatic updates should be downloaded
-    AutoUpdateEnabled bool
 }

 type Store struct {
@@ -413,7 +413,6 @@ export class Settings {
   ThinkLevel: string;
   SelectedModel: string;
   SidebarOpen: boolean;
-  AutoUpdateEnabled: boolean;

   constructor(source: any = {}) {
     if ('string' === typeof source) source = JSON.parse(source);

@@ -432,7 +431,6 @@ export class Settings {
     this.ThinkLevel = source["ThinkLevel"];
     this.SelectedModel = source["SelectedModel"];
     this.SidebarOpen = source["SidebarOpen"];
-    this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
   }
 }
 export class SettingsResponse {

@@ -469,46 +467,6 @@ export class HealthResponse {
     this.healthy = source["healthy"];
   }
 }
-export class UpdateInfo {
-  currentVersion: string;
-  availableVersion: string;
-  updateAvailable: boolean;
-  updateDownloaded: boolean;
-
-  constructor(source: any = {}) {
-    if ('string' === typeof source) source = JSON.parse(source);
-    this.currentVersion = source["currentVersion"];
-    this.availableVersion = source["availableVersion"];
-    this.updateAvailable = source["updateAvailable"];
-    this.updateDownloaded = source["updateDownloaded"];
-  }
-}
-export class UpdateCheckResponse {
-  updateInfo: UpdateInfo;
-
-  constructor(source: any = {}) {
-    if ('string' === typeof source) source = JSON.parse(source);
-    this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
-  }
-
-  convertValues(a: any, classs: any, asMap: boolean = false): any {
-    if (!a) {
-      return a;
-    }
-    if (Array.isArray(a)) {
-      return (a as any[]).map(elem => this.convertValues(elem, classs));
-    } else if ("object" === typeof a) {
-      if (asMap) {
-        for (const key of Object.keys(a)) {
-          a[key] = new classs(a[key]);
-        }
-        return a;
-      }
-      return new classs(a);
-    }
-    return a;
-  }
-}
 export class User {
   id: string;
   email: string;
@@ -414,54 +414,3 @@ export async function fetchHealth(): Promise<boolean> {
     return false;
   }
 }
-
-export async function getCurrentVersion(): Promise<string> {
-  try {
-    const response = await fetch(`${API_BASE}/api/version`, {
-      method: "GET",
-      headers: {
-        "Content-Type": "application/json",
-      },
-    });
-    if (response.ok) {
-      const data = await response.json();
-      return data.version || "Unknown";
-    }
-    return "Unknown";
-  } catch (error) {
-    console.error("Error fetching version:", error);
-    return "Unknown";
-  }
-}
-
-export async function checkForUpdate(): Promise<{
-  currentVersion: string;
-  availableVersion: string;
-  updateAvailable: boolean;
-  updateDownloaded: boolean;
-}> {
-  const response = await fetch(`${API_BASE}/api/v1/update/check`, {
-    method: "GET",
-    headers: {
-      "Content-Type": "application/json",
-    },
-  });
-  if (!response.ok) {
-    throw new Error("Failed to check for update");
-  }
-  const data = await response.json();
-  return data.updateInfo;
-}
-
-export async function installUpdate(): Promise<void> {
-  const response = await fetch(`${API_BASE}/api/v1/update/install`, {
-    method: "POST",
-    headers: {
-      "Content-Type": "application/json",
-    },
-  });
-  if (!response.ok) {
-    const error = await response.text();
-    throw new Error(error || "Failed to install update");
-  }
-}
@@ -14,13 +14,12 @@ import {
   XMarkIcon,
   CogIcon,
   ArrowLeftIcon,
-  ArrowDownTrayIcon,
 } from "@heroicons/react/20/solid";
 import { Settings as SettingsType } from "@/gotypes";
 import { useNavigate } from "@tanstack/react-router";
 import { useUser } from "@/hooks/useUser";
 import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
-import { getSettings, updateSettings, checkForUpdate } from "@/api";
+import { getSettings, updateSettings } from "@/api";

 function AnimatedDots() {
   return (

@@ -40,12 +39,6 @@ export default function Settings() {
   const queryClient = useQueryClient();
   const [showSaved, setShowSaved] = useState(false);
   const [restartMessage, setRestartMessage] = useState(false);
-  const [updateInfo, setUpdateInfo] = useState<{
-    currentVersion: string;
-    availableVersion: string;
-    updateAvailable: boolean;
-    updateDownloaded: boolean;
-  } | null>(null);
   const {
     user,
     isAuthenticated,

@@ -83,10 +76,6 @@ export default function Settings() {

   useEffect(() => {
     refetchUser();
-    // Check for updates on mount
-    checkForUpdate()
-      .then(setUpdateInfo)
-      .catch((err) => console.error("Error checking for update:", err));
   }, []); // eslint-disable-line react-hooks/exhaustive-deps

   useEffect(() => {

@@ -355,58 +344,6 @@ export default function Settings() {
   {/* Local Configuration */}
   <div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
     <div className="space-y-4 p-4">
-      {/* Auto Update */}
-      <Field>
-        <div className="flex items-start justify-between gap-4">
-          <div className="flex items-start space-x-3 flex-1">
-            <ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
-            <div className="flex-1">
-              <Label>Auto-download updates</Label>
-              <Description>
-                {settings.AutoUpdateEnabled ? (
-                  <>
-                    Automatically downloads updates when available.
-                    <div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
-                      Current version: {updateInfo?.currentVersion || "Loading..."}
-                    </div>
-                  </>
-                ) : (
-                  <>
-                    Manually download updates.
-                    <div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
-                      <div className="space-y-2 text-sm">
-                        <div className="flex justify-between">
-                          <span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
-                        </div>
-                        {updateInfo?.availableVersion && (
-                          <div className="flex justify-between">
-                            <span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
-                          </div>
-                        )}
-                      </div>
-                      <a
-                        href="https://ollama.com/download"
-                        target="_blank"
-                        rel="noopener noreferrer"
-                        className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
-                      >
-                        Download new version →
-                      </a>
-                    </div>
-                  </>
-                )}
-              </Description>
-            </div>
-          </div>
-          <div className="flex-shrink-0">
-            <Switch
-              checked={settings.AutoUpdateEnabled}
-              onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
-            />
-          </div>
-        </div>
-      </Field>
-
       {/* Expose Ollama */}
       <Field>
         <div className="flex items-start justify-between gap-4">
@@ -100,17 +100,6 @@ type HealthResponse struct {
     Healthy bool `json:"healthy"`
 }

-type UpdateInfo struct {
-    CurrentVersion   string `json:"currentVersion"`
-    AvailableVersion string `json:"availableVersion"`
-    UpdateAvailable  bool   `json:"updateAvailable"`
-    UpdateDownloaded bool   `json:"updateDownloaded"`
-}
-
-type UpdateCheckResponse struct {
-    UpdateInfo UpdateInfo `json:"updateInfo"`
-}
-
 type User struct {
     ID    string `json:"id"`
     Email string `json:"email"`
app/ui/ui.go (100 changed lines)

@@ -28,7 +28,6 @@ import (
     "github.com/ollama/ollama/app/tools"
     "github.com/ollama/ollama/app/types/not"
     "github.com/ollama/ollama/app/ui/responses"
-    "github.com/ollama/ollama/app/updater"
     "github.com/ollama/ollama/app/version"
     ollamaAuth "github.com/ollama/ollama/auth"
     "github.com/ollama/ollama/envconfig"

@@ -107,18 +106,6 @@ type Server struct {

     // Dev is true if the server is running in development mode
     Dev bool
-
-    // Updater for checking and downloading updates
-    Updater             UpdaterInterface
-    UpdateAvailableFunc func()
-}
-
-// UpdaterInterface defines the methods we need from the updater
-type UpdaterInterface interface {
-    CheckForUpdate(ctx context.Context) (bool, string, error)
-    InstallAndRestart() error
-    CancelOngoingDownload()
-    TriggerImmediateCheck()
 }

 func (s *Server) log() *slog.Logger {

@@ -297,8 +284,6 @@ func (s *Server) Handler() http.Handler {
     mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
     mux.Handle("GET /api/v1/settings", handle(s.getSettings))
     mux.Handle("POST /api/v1/settings", handle(s.settings))
-    mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
-    mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))

     // Ollama proxy endpoints
     ollamaProxy := s.ollamaProxy()
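The Handler hunk above unregisters the two update routes. The `"GET /api/v1/..."` strings are method-qualified patterns from the `net/http` mux in Go 1.22 and later, where the mux itself enforces the verb; a minimal sketch of that routing style, with illustrative handlers and paths:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

func main() {
	// Method-qualified patterns (Go 1.22+): the mux enforces the verb itself
	// and answers 405 Method Not Allowed for other methods on the same path.
	mux := http.NewServeMux()
	mux.Handle("GET /api/v1/settings", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, `{"sidebarOpen":true}`)
	}))
	mux.Handle("POST /api/v1/settings", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}))

	server := httptest.NewServer(mux)
	defer server.Close()

	resp, _ := http.Get(server.URL + "/api/v1/settings")
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close()
	fmt.Println(resp.StatusCode, string(body)) // 200 {"sidebarOpen":true}
}
```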
@@ -1463,24 +1448,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
         return fmt.Errorf("failed to save settings: %w", err)
     }

-    // Handle auto-update toggle changes
-    if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
-        if !settings.AutoUpdateEnabled {
-            // Auto-update disabled: cancel any ongoing download
-            if s.Updater != nil {
-                s.Updater.CancelOngoingDownload()
-            }
-        } else {
-            // Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
-            if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
-                s.UpdateAvailableFunc()
-            } else if s.Updater != nil {
-                // Trigger the background checker to run immediately
-                s.Updater.TriggerImmediateCheck()
-            }
-        }
-    }
-
     if old.ContextLength != settings.ContextLength ||
         old.Models != settings.Models ||
         old.Expose != settings.Expose {

@@ -1557,73 +1524,6 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
     return json.NewEncoder(w).Encode(response)
 }

-func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
-    currentVersion := version.Version
-
-    if s.Updater == nil {
-        return fmt.Errorf("updater not available")
-    }
-
-    updateAvailable, updateVersion, err := s.Updater.CheckForUpdate(r.Context())
-    if err != nil {
-        s.log().Warn("failed to check for update", "error", err)
-        // Don't return error, just log it and continue with no update available
-    }
-
-    response := responses.UpdateCheckResponse{
-        UpdateInfo: responses.UpdateInfo{
-            CurrentVersion:   currentVersion,
-            AvailableVersion: updateVersion,
-            UpdateAvailable:  updateAvailable,
-            UpdateDownloaded: updater.UpdateDownloaded,
-        },
-    }
-
-    w.Header().Set("Content-Type", "application/json")
-    return json.NewEncoder(w).Encode(response)
-}
-
-func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
-    if r.Method != "POST" {
-        return fmt.Errorf("method not allowed")
-    }
-
-    if s.Updater == nil {
-        s.log().Error("install failed: updater not available")
-        return fmt.Errorf("updater not available")
-    }
-
-    // Check if update is downloaded
-    if !updater.UpdateDownloaded {
-        s.log().Error("install failed: no update downloaded")
-        return fmt.Errorf("no update downloaded")
-    }
-
-    // Send response before restarting
-    response := map[string]any{
-        "success": true,
-        "message": "Installing update and restarting...",
-    }
-
-    w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(response); err != nil {
-        return err
-    }
-
-    // Give the response time to be sent
-    time.Sleep(500 * time.Millisecond)
-
-    // Trigger the upgrade and restart
-    go func() {
-        time.Sleep(500 * time.Millisecond)
-        if err := s.Updater.InstallAndRestart(); err != nil {
-            s.log().Error("failed to install update", "error", err)
-        }
-    }()
-
-    return nil
-}
-
 func userAgent() string {
     buildinfo, _ := debug.ReadBuildInfo()

@@ -19,7 +19,6 @@ import (
     "runtime"
     "strconv"
     "strings"
-    "sync"
    "time"

     "github.com/ollama/ollama/app/store"

@@ -59,8 +58,7 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
     query := requestURL.Query()
     query.Add("os", runtime.GOOS)
     query.Add("arch", runtime.GOARCH)
-    currentVersion := version.Version
-    query.Add("version", currentVersion)
+    query.Add("version", version.Version)
     query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))

     // The original macOS app used to use the device ID

@@ -133,27 +131,15 @@
 }

 func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
-    // Create a cancellable context for this download
-    downloadCtx, cancel := context.WithCancel(ctx)
-    u.cancelDownloadLock.Lock()
-    u.cancelDownload = cancel
-    u.cancelDownloadLock.Unlock()
-    defer func() {
-        u.cancelDownloadLock.Lock()
-        u.cancelDownload = nil
-        u.cancelDownloadLock.Unlock()
-        cancel()
-    }()
-
     // Do a head first to check etag info
-    req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
+    req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
     if err != nil {
         return err
     }

     // In case of slow downloads, continue the update check in the background
-    bgctx, bgcancel := context.WithCancel(downloadCtx)
-    defer bgcancel()
+    bgctx, cancel := context.WithCancel(ctx)
+    defer cancel()
     go func() {
         for {
             select {
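The `DownloadNewRelease` hunk reverts to relying on the caller's context alone, deleting the stored `cancelDownload` func and its mutex. A compact, runnable sketch of the removed pattern, in which a per-operation `context.CancelFunc` is published behind a mutex so another goroutine can abort the work (the names mirror the removed fields; the timed select stands in for the real download):

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type downloader struct {
	mu     sync.Mutex
	cancel context.CancelFunc
}

// start derives a cancellable context for one download and publishes its
// cancel func so Cancel can abort the operation from another goroutine.
func (d *downloader) start(ctx context.Context) error {
	ctx, cancel := context.WithCancel(ctx)
	d.mu.Lock()
	d.cancel = cancel
	d.mu.Unlock()
	defer func() {
		d.mu.Lock()
		d.cancel = nil // the operation is over; nothing left to cancel
		d.mu.Unlock()
		cancel()
	}()

	select {
	case <-time.After(5 * time.Second): // stand-in for the real download
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// Cancel aborts the in-flight download, if any.
func (d *downloader) Cancel() {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.cancel != nil {
		d.cancel()
		d.cancel = nil
	}
}

func main() {
	d := &downloader{}
	go func() {
		time.Sleep(100 * time.Millisecond)
		d.Cancel()
	}()
	fmt.Println(d.start(context.Background())) // context canceled
}
```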
@@ -190,7 +176,6 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
     _, err = os.Stat(stageFilename)
     if err == nil {
         slog.Info("update already downloaded", "bundle", stageFilename)
-        UpdateDownloaded = true
         return nil
     }

@@ -259,95 +244,34 @@ func cleanupOldDownloads(stageDir string) {
 }

 type Updater struct {
     Store *store.Store
-    cancelDownload     context.CancelFunc
-    cancelDownloadLock sync.Mutex
-    checkNow           chan struct{}
-}
-
-// CancelOngoingDownload cancels any currently running download
-func (u *Updater) CancelOngoingDownload() {
-    u.cancelDownloadLock.Lock()
-    defer u.cancelDownloadLock.Unlock()
-    if u.cancelDownload != nil {
-        slog.Info("cancelling ongoing update download")
-        u.cancelDownload()
-        u.cancelDownload = nil
-    }
-}
-
-// TriggerImmediateCheck signals the background checker to check for updates immediately
-func (u *Updater) TriggerImmediateCheck() {
-    if u.checkNow != nil {
-        u.checkNow <- struct{}{}
-    }
 }

 func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
-    u.checkNow = make(chan struct{}, 1)
     go func() {
         // Don't blast an update message immediately after startup
         time.Sleep(UpdateCheckInitialDelay)
         slog.Info("beginning update checker", "interval", UpdateCheckInterval)
-        ticker := time.NewTicker(UpdateCheckInterval)
-        defer ticker.Stop()
-
         for {
+            available, resp := u.checkForUpdate(ctx)
+            if available {
+                err := u.DownloadNewRelease(ctx, resp)
+                if err != nil {
+                    slog.Error(fmt.Sprintf("failed to download new release: %s", err))
+                } else {
+                    err = cb(resp.UpdateVersion)
+                    if err != nil {
+                        slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
+                    }
+                }
+            }
             select {
             case <-ctx.Done():
                 slog.Debug("stopping background update checker")
                 return
-            case <-u.checkNow:
-                // Immediate check triggered
-            case <-ticker.C:
-                // Regular interval check
-            }
-
-            // Always check for updates
-            available, resp := u.checkForUpdate(ctx)
-            if !available {
-                continue
-            }
-
-            // Update is available - check if auto-update is enabled for downloading
-            settings, err := u.Store.Settings()
-            if err != nil {
-                slog.Error("failed to load settings", "error", err)
-                continue
-            }
-
-            if !settings.AutoUpdateEnabled {
-                // Auto-update disabled - don't download, just log
-                slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
-                continue
-            }
-
-            // Auto-update is enabled - download
-            err = u.DownloadNewRelease(ctx, resp)
-            if err != nil {
-                slog.Error("failed to download new release", "error", err)
-                continue
-            }
-
-            // Download successful - show tray notification (regardless of toggle state)
-            err = cb(resp.UpdateVersion)
-            if err != nil {
-                slog.Warn("failed to register update available with tray", "error", err)
+            default:
+                time.Sleep(UpdateCheckInterval)
             }
         }
     }()
 }
-
-func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
-    available, resp := u.checkForUpdate(ctx)
-    return available, resp.UpdateVersion, nil
-}
-
-func (u *Updater) InstallAndRestart() error {
-    if !UpdateDownloaded {
-        return fmt.Errorf("no update downloaded")
-    }
-
-    slog.Info("installing update and restarting")
-    return DoUpgrade(true)
-}
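This hunk swaps a ticker-plus-trigger-channel loop back to a simple sleep-based poll. A minimal runnable sketch of the removed select pattern, with a `time.Ticker` for the regular cadence and a buffered channel that lets callers force an immediate check:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// poll runs check on a fixed cadence, but also immediately whenever a value
// arrives on checkNow, and stops as soon as ctx is cancelled.
func poll(ctx context.Context, interval time.Duration, checkNow <-chan struct{}, check func()) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-checkNow: // immediate check triggered by a caller
		case <-ticker.C: // regular interval check
		}
		check()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
	defer cancel()
	checkNow := make(chan struct{}, 1) // buffered: a trigger never blocks the caller
	go poll(ctx, time.Hour, checkNow, func() { fmt.Println("checked") })
	checkNow <- struct{}{}
	<-ctx.Done()
}
```

One property worth noting: unlike `time.Sleep(interval)`, the select form reacts to `ctx.Done()` immediately instead of only between sleeps.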
@@ -11,7 +11,6 @@ import (
     "log/slog"
     "net/http"
     "net/http/httptest"
-    "sync/atomic"
     "testing"
     "time"

@@ -86,17 +85,7 @@ func TestBackgoundChecker(t *testing.T) {
     UpdateCheckURLBase = server.URL + "/update.json"

     updater := &Updater{Store: &store.Store{}}
-    defer updater.Store.Close()
-
-    settings, err := updater.Store.Settings()
-    if err != nil {
-        t.Fatal(err)
-    }
-    settings.AutoUpdateEnabled = true
-    if err := updater.Store.SetSettings(settings); err != nil {
-        t.Fatal(err)
-    }
-
+    defer updater.Store.Close() // Ensure database is closed
     updater.StartBackgroundUpdaterChecker(ctx, cb)
     select {
     case <-stallTimer.C:

@@ -110,187 +99,3 @@ func TestBackgoundChecker(t *testing.T) {
         }
     }
 }
-
-func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
-    UpdateStageDir = t.TempDir()
-    var downloadAttempted atomic.Bool
-    done := make(chan struct{})
-
-    ctx, cancel := context.WithCancel(t.Context())
-    defer cancel()
-    UpdateCheckInitialDelay = 5 * time.Millisecond
-    UpdateCheckInterval = 5 * time.Millisecond
-    VerifyDownload = func() error {
-        return nil
-    }
-
-    var server *httptest.Server
-    server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        if r.URL.Path == "/update.json" {
-            w.Write([]byte(
-                fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
-                    server.URL+"/9.9.9/"+Installer)))
-        } else if r.URL.Path == "/9.9.9/"+Installer {
-            downloadAttempted.Store(true)
-            buf := &bytes.Buffer{}
-            zw := zip.NewWriter(buf)
-            zw.Close()
-            io.Copy(w, buf)
-        }
-    }))
-    defer server.Close()
-    UpdateCheckURLBase = server.URL + "/update.json"
-
-    updater := &Updater{Store: &store.Store{}}
-    defer updater.Store.Close()
-
-    // Ensure auto-update is disabled
-    settings, err := updater.Store.Settings()
-    if err != nil {
-        t.Fatal(err)
-    }
-    settings.AutoUpdateEnabled = false
-    if err := updater.Store.SetSettings(settings); err != nil {
-        t.Fatal(err)
-    }
-
-    cb := func(ver string) error {
-        t.Fatal("callback should not be called when auto-update is disabled")
-        return nil
-    }
-
-    updater.StartBackgroundUpdaterChecker(ctx, cb)
-
-    // Wait enough time for multiple check cycles
-    time.Sleep(50 * time.Millisecond)
-    close(done)
-
-    if downloadAttempted.Load() {
-        t.Fatal("download should not be attempted when auto-update is disabled")
-    }
-}
-
-func TestCancelOngoingDownload(t *testing.T) {
-    UpdateStageDir = t.TempDir()
-    downloadStarted := make(chan struct{})
-    downloadCancelled := make(chan struct{})
-
-    ctx := t.Context()
-    VerifyDownload = func() error {
-        return nil
-    }
-
-    var server *httptest.Server
-    server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        if r.URL.Path == "/update.json" {
-            w.Write([]byte(
-                fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
-                    server.URL+"/9.9.9/"+Installer)))
-        } else if r.URL.Path == "/9.9.9/"+Installer {
-            if r.Method == http.MethodHead {
-                w.Header().Set("Content-Length", "1000000")
-                w.WriteHeader(http.StatusOK)
-                return
-            }
-            // Signal that download has started
-            close(downloadStarted)
-            // Wait for cancellation or timeout
-            select {
-            case <-r.Context().Done():
-                close(downloadCancelled)
-                return
-            case <-time.After(5 * time.Second):
-                t.Error("download was not cancelled in time")
-            }
-        }
-    }))
-    defer server.Close()
-    UpdateCheckURLBase = server.URL + "/update.json"
-
-    updater := &Updater{Store: &store.Store{}}
-    defer updater.Store.Close()
-
-    _, resp := updater.checkForUpdate(ctx)
-
-    // Start download in goroutine
-    go func() {
-        _ = updater.DownloadNewRelease(ctx, resp)
-    }()
-
-    // Wait for download to start
-    select {
-    case <-downloadStarted:
-    case <-time.After(2 * time.Second):
-        t.Fatal("download did not start in time")
-    }
-
-    // Cancel the download
-    updater.CancelOngoingDownload()
-
-    // Verify cancellation was received
-    select {
-    case <-downloadCancelled:
-        // Success
-    case <-time.After(2 * time.Second):
-        t.Fatal("download cancellation was not received by server")
-    }
-}
-
-func TestTriggerImmediateCheck(t *testing.T) {
-    UpdateStageDir = t.TempDir()
-    checkCount := atomic.Int32{}
-    checkDone := make(chan struct{}, 10)
-
-    ctx, cancel := context.WithCancel(t.Context())
-    defer cancel()
-    // Set a very long interval so only TriggerImmediateCheck causes checks
-    UpdateCheckInitialDelay = 1 * time.Millisecond
-    UpdateCheckInterval = 1 * time.Hour
-    VerifyDownload = func() error {
-        return nil
-    }
-
-    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        if r.URL.Path == "/update.json" {
-            checkCount.Add(1)
-            select {
-            case checkDone <- struct{}{}:
-            default:
-            }
-            // Return no update available
-            w.WriteHeader(http.StatusNoContent)
-        }
-    }))
-    defer server.Close()
-    UpdateCheckURLBase = server.URL + "/update.json"
-
-    updater := &Updater{Store: &store.Store{}}
-    defer updater.Store.Close()
-
-    cb := func(ver string) error {
-        return nil
-    }
-
-    updater.StartBackgroundUpdaterChecker(ctx, cb)
-
-    // Wait for goroutine to start and pass initial delay
-    time.Sleep(10 * time.Millisecond)
-
-    // With 1 hour interval, no check should have happened yet
-    initialCount := checkCount.Load()
-
-    // Trigger immediate check
-    updater.TriggerImmediateCheck()
-
-    // Wait for the triggered check
-    select {
-    case <-checkDone:
-    case <-time.After(2 * time.Second):
-        t.Fatal("triggered check did not happen")
-    }
-
-    finalCount := checkCount.Load()
-    if finalCount <= initialCount {
-        t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
-    }
-}
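The three deleted tests share one scaffold: an `httptest.Server` that serves a fake update manifest whose URL points back at the same server for the installer bundle. A trimmed, self-contained sketch of that scaffold (package name assumed; the paths and JSON shape mirror the removed tests, and driving the real checker against it is omitted):

```go
package updater

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// A trimmed sketch of the scaffold the removed tests shared. The handler
// builds the bundle URL from the closed-over server variable, which is why
// server is declared before httptest.NewServer is called.
func TestUpdateManifestScaffold(t *testing.T) {
	var server *httptest.Server
	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/update.json":
			fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/installer")
		case "/9.9.9/installer":
			w.Write([]byte("fake bundle")) // stand-in for a zipped installer
		}
	}))
	defer server.Close()

	resp, err := http.Get(server.URL + "/update.json")
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.Contains(string(body), `"version": "9.9.9"`) {
		t.Fatalf("unexpected manifest: %s", body)
	}
}
```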
@@ -369,6 +369,25 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
 	return nil
 }
 
+// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
+// 	const ERROR_SUCCESS syscall.Errno = 0
+
+// 	t.muMenus.RLock()
+// 	menu := uintptr(t.menus[parentId])
+// 	t.muMenus.RUnlock()
+// 	res, _, err := pRemoveMenu.Call(
+// 		menu,
+// 		uintptr(menuItemId),
+// 		MF_BYCOMMAND,
+// 	)
+// 	if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
+// 		return err
+// 	}
+// 	t.delFromVisibleItems(parentId, menuItemId)
+
+// 	return nil
+// }
+
 func (t *winTray) showMenu() error {
 	p := point{}
 	boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))
@@ -51,6 +51,7 @@ const (
 	IMAGE_ICON      = 1          // Loads an icon
 	LR_DEFAULTSIZE  = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
 	LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
+	MF_BYCOMMAND    = 0x00000000
 	MFS_DISABLED    = 0x00000003
 	MFT_SEPARATOR   = 0x00000800
 	MFT_STRING      = 0x00000000
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		Prompt:         ">>> ",
 		AltPrompt:      "... ",
 		Placeholder:    "Send a message (/? for help)",
-		AltPlaceholder: "Press Enter to send",
+		AltPlaceholder: `Use """ to end multi-line input`,
 	})
 	if err != nil {
 		return err
@@ -21,7 +21,6 @@ ollama pull glm-4.7:cloud
 To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
 
 ```shell
-export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama # required but ignored
 ```
@@ -248,13 +247,12 @@ curl -X POST http://localhost:11434/v1/messages \
 [Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
 
 ```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
 ```
 
 Or set the environment variables in your shell profile:
 
 ```shell
-export ANTHROPIC_AUTH_TOKEN=ollama
 export ANTHROPIC_BASE_URL=http://localhost:11434
 export ANTHROPIC_API_KEY=ollama
 ```
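Note: a quick way to sanity-check the Anthropic-compatible endpoint configured above is a direct request to `/v1/messages` (the endpoint referenced in the hunk header). This is a minimal sketch, not part of the diff; the request body follows the standard Anthropic Messages API shape, and `qwen3-coder` is just the example model used elsewhere on this page:

```shell
# Minimal smoke test against Ollama's Anthropic-compatible endpoint.
# Assumes a local server on the default port and the model already pulled.
curl -X POST http://localhost:11434/v1/messages \
  -H "content-type: application/json" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Say hello"}]
  }'
```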
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
 import { Ollama } from "ollama";
 
 const client = new Ollama();
-const results = await client.webSearch("what is ollama?");
+const results = await client.webSearch({ query: "what is ollama?" });
 console.log(JSON.stringify(results, null, 2));
 ```
 
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
 import { Ollama } from "ollama";
 
 const client = new Ollama();
-const fetchResult = await client.webFetch("https://ollama.com");
+const fetchResult = await client.webFetch({ url: "https://ollama.com" });
 console.log(JSON.stringify(fetchResult, null, 2));
 ```
 
@@ -111,9 +111,7 @@
       "/integrations/zed",
       "/integrations/roo-code",
       "/integrations/n8n",
-      "/integrations/xcode",
-      "/integrations/onyx",
-      "/integrations/marimo"
+      "/integrations/xcode"
     ]
   },
   {
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
 
 ## How can I specify the context window size?
 
-By default, Ollama uses a context window size of 4096 tokens.
+By default, Ollama uses a context window size of 2048 tokens.
 
 This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
 
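Note: the hunk above ends just before the example it refers to. The documented usage (a sketch for context, unchanged by this diff) is to set the variable when starting the server:

```shell
# Serve with an 8K default context window instead of the built-in default.
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```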
[Deleted binary files: 9 screenshots used by the marimo and Onyx integration docs (174 KiB, 80 KiB, 230 KiB, 178 KiB, 186 KiB, 100 KiB, 306 KiB, 300 KiB, 211 KiB)]
@@ -25,7 +25,6 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
 1. Set the environment variables:
 
    ```shell
-   export ANTHROPIC_AUTH_TOKEN=ollama
    export ANTHROPIC_BASE_URL=http://localhost:11434
    export ANTHROPIC_API_KEY=ollama
    ```
@@ -39,7 +38,7 @@ claude --model qwen3-coder
 Or run with environment variables inline:
 
 ```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
 ```
 
 ## Connecting to ollama.com
@@ -1,73 +0,0 @@
----
-title: marimo
----
-
-## Install
-
-Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
-can also use `uv` to create a sandboxed environment for marimo by running:
-
-```
-uvx marimo edit --sandbox notebook.py
-```
-
-## Usage with Ollama
-
-1. In marimo, go to the user settings and go to the AI tab. From here
-   you can find and configure Ollama as an AI provider. For local use you
-   would typically point the base url to `http://localhost:11434/v1`.
-
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/marimo-settings.png"
-    alt="Ollama settings in marimo"
-    width="50%"
-  />
-</div>
-
-2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
-
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/marimo-models.png"
-    alt="Selecting an Ollama model"
-    width="50%"
-  />
-</div>
-
-3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
-
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/marimo-add-model.png"
-    alt="Adding a new Ollama model"
-    width="50%"
-  />
-</div>
-
-4. Once configured, you can now use Ollama for AI chats in marimo.
-
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/marimo-chat.png"
-    alt="Configure code completion"
-    width="50%"
-  />
-</div>
-
-4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
-
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/marimo-code-completion.png"
-    alt="Configure code completion"
-    width="50%"
-  />
-</div>
-
-
-## Connecting to ollama.com
-
-1. Sign in to ollama cloud via `ollama signin`
-2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
-3. You can now refer to this model in marimo!
@@ -1,63 +0,0 @@
----
-title: Onyx
----
-
-## Overview
-[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
-- Creating custom Agents
-- Web search
-- Deep Research
-- RAG over uploaded documents and connected apps
-- Connectors to applications like Google Drive, Email, Slack, etc.
-- MCP and OpenAPI Actions support
-- Image generation
-- User/Groups management, RBAC, SSO, etc.
-
-Onyx can be deployed for single users or large organizations.
-
-## Install Onyx
-
-Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
-
-<Info>
-  Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
-</Info>
-
-## Usage with Ollama
-
-1. Login to your Onyx deployment (create an account first).
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/onyx-login.png"
-    alt="Onyx Login Page"
-    width="75%"
-  />
-</div>
-2. In the set-up process select `Ollama` as the LLM provider.
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/onyx-ollama-llm.png"
-    alt="Onyx Set Up Form"
-    width="75%"
-  />
-</div>
-3. Provide your **Ollama API URL** and select your models.
-<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/onyx-ollama-form.png"
-    alt="Selecting Ollama Models"
-    width="75%"
-  />
-</div>
-
-You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
-
-## Send your first query
-<div style={{ display: 'flex', justifyContent: 'center' }}>
-  <img
-    src="/images/onyx-query.png"
-    alt="Onyx Query Example"
-    width="75%"
-  />
-</div>
@@ -1,5 +1,5 @@
 ---
-title: Linux
+title: "Linux"
 ---
 
 ## Install
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
 ## Manual install
 
 <Note>
-If you are upgrading from a prior version, you should remove the old libraries
-with `sudo rm -rf /usr/lib/ollama` first.
+If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
 </Note>
 
 Download and extract the package:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+    | sudo tar zx -C /usr
 ```
 
 Start Ollama:
@@ -41,8 +40,8 @@ ollama -v
 If you have an AMD GPU, also download and extract the additional ROCm package:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
+    | sudo tar zx -C /usr
 ```
 
 ### ARM64 install
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
 Download and extract the ARM64-specific package:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
+    | sudo tar zx -C /usr
 ```
 
 ### Adding Ollama as a startup service (recommended)
@@ -113,11 +112,7 @@ sudo systemctl status ollama
 ```
 
 <Note>
-While AMD has contributed the `amdgpu` driver upstream to the official linux
-kernel source, the version is older and may not support all ROCm features. We
-recommend you install the latest driver from
-https://www.amd.com/en/support/linux-drivers for best support of your Radeon
-GPU.
+While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
 </Note>
 
 ## Customizing
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
 Or by re-downloading Ollama:
 
 ```shell
-curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
-    | sudo tar x -C /usr
+curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
+    | sudo tar zx -C /usr
 ```
 
 ## Installing specific versions
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
 sudo userdel ollama
 sudo groupdel ollama
 sudo rm -r /usr/share/ollama
 ```
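Note: the uninstall hunk above removes the models directory and the service user/group but not the binary itself. Assuming the manual-install layout used earlier on this page (binaries under `/usr`), the binary would be removed separately; a sketch, not part of this diff:

```shell
# Remove the ollama binary installed by the manual steps above.
sudo rm $(which ollama)
```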
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
 				t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
 			}
 
-			if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
+			if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
 				t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
 			}
 		case <-ctx.Done():
@@ -1464,11 +1464,6 @@ type CompletionRequest struct {
 
 	// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
 	TopLogprobs int
-
-	// Image generation fields
-	Width  int32 `json:"width,omitempty"`
-	Height int32 `json:"height,omitempty"`
-	Seed   int64 `json:"seed,omitempty"`
 }
 
 // DoneReason represents the reason why a completion response is done
@@ -1517,11 +1512,6 @@ type CompletionResponse struct {
 
 	// Logprobs contains log probability information if requested
 	Logprobs []Logprob `json:"logprobs,omitempty"`
-
-	// Image generation fields
-	Image []byte `json:"image,omitempty"` // Generated image
-	Step  int    `json:"step,omitempty"`  // Current generation step
-	Total int    `json:"total,omitempty"` // Total generation steps
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -8,7 +8,6 @@ import (
 	"math/rand"
 	"net/http"
 	"strings"
-	"time"
 
 	"github.com/gin-gonic/gin"
 
@@ -442,7 +441,6 @@ type ResponsesWriter struct {
 	stream     bool
 	responseID string
 	itemID     string
-	request    openai.ResponsesRequest
 }
 
 func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -480,9 +478,7 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
 
 	// Non-streaming response
 	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
-	completedAt := time.Now().Unix()
-	response.CompletedAt = &completedAt
+	response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
 	return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
 }
 
@@ -527,12 +523,11 @@ func ResponsesMiddleware() gin.HandlerFunc {
 
 		w := &ResponsesWriter{
 			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
+			converter:  openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
 			model:      req.Model,
 			stream:     streamRequested,
 			responseID: responseID,
 			itemID:     itemID,
-			request:    req,
 		}
 
 		// Set headers based on streaming mode
@@ -630,10 +630,6 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
 
 // decodeImageURL decodes a base64 data URI into raw image bytes.
 func decodeImageURL(url string) (api.ImageData, error) {
-	if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
-		return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
-	}
-
 	types := []string{"jpeg", "jpg", "png", "webp"}
 
 	// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"math/rand"
-	"time"
 
 	"github.com/ollama/ollama/api"
 )
@@ -266,9 +265,9 @@ type ResponsesText struct {
 type ResponsesTool struct {
 	Type string `json:"type"` // "function"
 	Name string `json:"name"`
-	Description *string        `json:"description"` // nullable but required
-	Strict      *bool          `json:"strict"`      // nullable but required
-	Parameters  map[string]any `json:"parameters"`  // nullable but required
+	Description string         `json:"description,omitempty"`
+	Strict      bool           `json:"strict,omitempty"`
+	Parameters  map[string]any `json:"parameters,omitempty"`
 }
 
 type ResponsesRequest struct {
@@ -476,16 +475,11 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
 		}
 	}
 
-	var description string
-	if t.Description != nil {
-		description = *t.Description
-	}
-
 	return api.Tool{
 		Type: t.Type,
 		Function: api.ToolFunction{
 			Name:        t.Name,
-			Description: description,
+			Description: t.Description,
 			Parameters:  params,
 		},
 	}, nil
@@ -522,60 +516,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
 
 // Response types for the Responses API
 
-// ResponsesTextField represents the text output configuration in the response.
-type ResponsesTextField struct {
-	Format ResponsesTextFormat `json:"format"`
-}
-
-// ResponsesReasoningOutput represents reasoning configuration in the response.
-type ResponsesReasoningOutput struct {
-	Effort  *string `json:"effort,omitempty"`
-	Summary *string `json:"summary,omitempty"`
-}
-
-// ResponsesError represents an error in the response.
-type ResponsesError struct {
-	Code    string `json:"code"`
-	Message string `json:"message"`
-}
-
-// ResponsesIncompleteDetails represents details about why a response was incomplete.
-type ResponsesIncompleteDetails struct {
-	Reason string `json:"reason"`
-}
-
 type ResponsesResponse struct {
 	ID        string `json:"id"`
 	Object    string `json:"object"`
 	CreatedAt int64  `json:"created_at"`
-	CompletedAt        *int64                      `json:"completed_at"`
-	Status             string                      `json:"status"`
-	IncompleteDetails  *ResponsesIncompleteDetails `json:"incomplete_details"`
-	Model              string                      `json:"model"`
-	PreviousResponseID *string                     `json:"previous_response_id"`
-	Instructions       *string                     `json:"instructions"`
-	Output             []ResponsesOutputItem       `json:"output"`
-	Error              *ResponsesError             `json:"error"`
-	Tools              []ResponsesTool             `json:"tools"`
-	ToolChoice         any                         `json:"tool_choice"`
-	Truncation         string                      `json:"truncation"`
-	ParallelToolCalls  bool                        `json:"parallel_tool_calls"`
-	Text               ResponsesTextField          `json:"text"`
-	TopP               float64                     `json:"top_p"`
-	PresencePenalty    float64                     `json:"presence_penalty"`
-	FrequencyPenalty   float64                     `json:"frequency_penalty"`
-	TopLogprobs        int                         `json:"top_logprobs"`
-	Temperature        float64                     `json:"temperature"`
-	Reasoning          *ResponsesReasoningOutput   `json:"reasoning"`
-	Usage              *ResponsesUsage             `json:"usage"`
-	MaxOutputTokens    *int                        `json:"max_output_tokens"`
-	MaxToolCalls       *int                        `json:"max_tool_calls"`
-	Store              bool                        `json:"store"`
-	Background         bool                        `json:"background"`
-	ServiceTier        string                      `json:"service_tier"`
-	Metadata           map[string]any              `json:"metadata"`
-	SafetyIdentifier   *string                     `json:"safety_identifier"`
-	PromptCacheKey     *string                     `json:"prompt_cache_key"`
+	Status    string                `json:"status"`
+	Model     string                `json:"model"`
+	Output    []ResponsesOutputItem `json:"output"`
+	Usage     *ResponsesUsage       `json:"usage,omitempty"`
+	// TODO(drifkin): add `temperature` and `top_p` to the response, but this
+	// requires additional plumbing to find the effective values since the
+	// defaults can come from the model or the request
 }
 
 type ResponsesOutputItem struct {
@@ -599,39 +550,18 @@ type ResponsesReasoningSummary struct {
 }
 
 type ResponsesOutputContent struct {
 	Type string `json:"type"` // "output_text"
 	Text string `json:"text"`
-	Annotations []any `json:"annotations"`
-	Logprobs    []any `json:"logprobs"`
-}
-
-type ResponsesInputTokensDetails struct {
-	CachedTokens int `json:"cached_tokens"`
-}
-
-type ResponsesOutputTokensDetails struct {
-	ReasoningTokens int `json:"reasoning_tokens"`
 }
 
 type ResponsesUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`
 	TotalTokens  int `json:"total_tokens"`
-	InputTokensDetails  ResponsesInputTokensDetails  `json:"input_tokens_details"`
-	OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
 }
 
-// derefFloat64 returns the value of a float64 pointer, or a default if nil.
-func derefFloat64(p *float64, def float64) float64 {
-	if p != nil {
-		return *p
-	}
-	return def
-}
-
-// ToResponse converts an api.ChatResponse to a Responses API response.
-// The request is used to echo back request parameters in the response.
-func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
+// ToResponse converts an api.ChatResponse to a Responses API response
+func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
 	var output []ResponsesOutputItem
 
 	// Add reasoning item if thinking is present
@@ -655,7 +585,6 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
 		output = append(output, ResponsesOutputItem{
 			ID:        fmt.Sprintf("fc_%s_%d", responseID, i),
 			Type:      "function_call",
-			Status:    "completed",
 			CallID:    tc.ID,
 			Name:      tc.Function.Name,
 			Arguments: tc.Function.Arguments,
@@ -669,90 +598,25 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
 			Role: "assistant",
 			Content: []ResponsesOutputContent{
 				{
 					Type: "output_text",
 					Text: chatResponse.Message.Content,
-					Annotations: []any{},
-					Logprobs:    []any{},
 				},
 			},
 		})
 	}
 
-	var instructions *string
-	if request.Instructions != "" {
-		instructions = &request.Instructions
-	}
-
-	// Build truncation with default
-	truncation := "disabled"
-	if request.Truncation != nil {
-		truncation = *request.Truncation
-	}
-
-	tools := request.Tools
-	if tools == nil {
-		tools = []ResponsesTool{}
-	}
-
-	text := ResponsesTextField{
-		Format: ResponsesTextFormat{Type: "text"},
-	}
-	if request.Text != nil && request.Text.Format != nil {
-		text.Format = *request.Text.Format
-	}
-
-	// Build reasoning output from request
-	var reasoning *ResponsesReasoningOutput
-	if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
-		reasoning = &ResponsesReasoningOutput{}
-		if request.Reasoning.Effort != "" {
-			reasoning.Effort = &request.Reasoning.Effort
-		}
-		if request.Reasoning.Summary != "" {
-			reasoning.Summary = &request.Reasoning.Summary
-		}
-	}
-
 	return ResponsesResponse{
 		ID:        responseID,
 		Object:    "response",
 		CreatedAt: chatResponse.CreatedAt.Unix(),
-		CompletedAt:        nil, // Set by middleware when writing final response
-		Status:             "completed",
-		IncompleteDetails:  nil, // Only populated if response incomplete
-		Model:              model,
-		PreviousResponseID: nil, // Not supported
-		Instructions:       instructions,
-		Output:             output,
-		Error:              nil, // Only populated on failure
-		Tools:              tools,
-		ToolChoice:         "auto", // Default value
-		Truncation:         truncation,
-		ParallelToolCalls:  true, // Default value
-		Text:               text,
-		TopP:               derefFloat64(request.TopP, 1.0),
-		PresencePenalty:    0, // Default value
-		FrequencyPenalty:   0, // Default value
-		TopLogprobs:        0, // Default value
-		Temperature:        derefFloat64(request.Temperature, 1.0),
-		Reasoning:          reasoning,
+		Status:    "completed",
+		Model:     model,
+		Output:    output,
 		Usage: &ResponsesUsage{
 			InputTokens:  chatResponse.PromptEvalCount,
 			OutputTokens: chatResponse.EvalCount,
 			TotalTokens:  chatResponse.PromptEvalCount + chatResponse.EvalCount,
-			// TODO(drifkin): wire through the actual values
-			InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
-			// TODO(drifkin): wire through the actual values
-			OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
 		},
-		MaxOutputTokens:  request.MaxOutputTokens,
-		MaxToolCalls:     nil,   // Not supported
-		Store:            false, // We don't store responses
-		Background:       request.Background,
-		ServiceTier:      "default", // Default value
-		Metadata:         map[string]any{},
-		SafetyIdentifier: nil, // Not supported
-		PromptCacheKey:   nil, // Not supported
 	}
 }
 
@@ -772,7 +636,6 @@ type ResponsesStreamConverter struct {
 	responseID string
 	itemID     string
 	model      string
-	request    ResponsesRequest
 
 	// State tracking (mutated across Process calls)
 	firstWrite bool
@@ -805,12 +668,11 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
 }
 
 // NewResponsesStreamConverter creates a new converter with the given configuration.
-func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
+func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
 	return &ResponsesStreamConverter{
 		responseID: responseID,
 		itemID:     itemID,
 		model:      model,
-		request:    request,
 		firstWrite: true,
 	}
 }
@@ -855,120 +717,25 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
 	return events
 }
 
-// buildResponseObject creates a full response object with all required fields for streaming events.
-func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
-	var instructions any = nil
-	if c.request.Instructions != "" {
-		instructions = c.request.Instructions
-	}
-
-	truncation := "disabled"
-	if c.request.Truncation != nil {
-		truncation = *c.request.Truncation
-	}
-
-	var tools []any
-	if c.request.Tools != nil {
-		for _, t := range c.request.Tools {
-			tools = append(tools, map[string]any{
-				"type":        t.Type,
-				"name":        t.Name,
-				"description": t.Description,
-				"strict":      t.Strict,
-				"parameters":  t.Parameters,
-			})
-		}
-	}
-	if tools == nil {
-		tools = []any{}
-	}
-
-	textFormat := map[string]any{"type": "text"}
-	if c.request.Text != nil && c.request.Text.Format != nil {
-		textFormat = map[string]any{
-			"type": c.request.Text.Format.Type,
-		}
-		if c.request.Text.Format.Name != "" {
-			textFormat["name"] = c.request.Text.Format.Name
-		}
-		if c.request.Text.Format.Schema != nil {
-			textFormat["schema"] = c.request.Text.Format.Schema
-		}
-		if c.request.Text.Format.Strict != nil {
-			textFormat["strict"] = *c.request.Text.Format.Strict
-		}
-	}
-
-	var reasoning any = nil
-	if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
-		r := map[string]any{}
-		if c.request.Reasoning.Effort != "" {
-			r["effort"] = c.request.Reasoning.Effort
-		} else {
-			r["effort"] = nil
-		}
-		if c.request.Reasoning.Summary != "" {
-			r["summary"] = c.request.Reasoning.Summary
-		} else {
-			r["summary"] = nil
-		}
-		reasoning = r
-	}
-
-	// Build top_p and temperature with defaults
-	topP := 1.0
-	if c.request.TopP != nil {
-		topP = *c.request.TopP
-	}
-	temperature := 1.0
-	if c.request.Temperature != nil {
-		temperature = *c.request.Temperature
-	}
-
-	return map[string]any{
-		"id":                   c.responseID,
-		"object":               "response",
-		"created_at":           time.Now().Unix(),
-		"completed_at":         nil,
-		"status":               status,
-		"incomplete_details":   nil,
-		"model":                c.model,
-		"previous_response_id": nil,
-		"instructions":         instructions,
-		"output":               output,
-		"error":                nil,
-		"tools":                tools,
-		"tool_choice":          "auto",
-		"truncation":           truncation,
-		"parallel_tool_calls":  true,
-		"text":                 map[string]any{"format": textFormat},
-		"top_p":                topP,
-		"presence_penalty":     0,
-		"frequency_penalty":    0,
-		"top_logprobs":         0,
-		"temperature":          temperature,
-		"reasoning":            reasoning,
-		"usage":                usage,
-		"max_output_tokens":    c.request.MaxOutputTokens,
-		"max_tool_calls":       nil,
-		"store":                false,
-		"background":           c.request.Background,
-		"service_tier":         "default",
-		"metadata":             map[string]any{},
-		"safety_identifier":    nil,
-		"prompt_cache_key":     nil,
-	}
-}
-
 func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
 	return c.newEvent("response.created", map[string]any{
-		"response": c.buildResponseObject("in_progress", []any{}, nil),
+		"response": map[string]any{
+			"id":     c.responseID,
+			"object": "response",
+			"status": "in_progress",
+			"output": []any{},
+		},
 	})
 }
 
 func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
 	return c.newEvent("response.in_progress", map[string]any{
-		"response": c.buildResponseObject("in_progress", []any{}, nil),
+		"response": map[string]any{
+			"id":     c.responseID,
+			"object": "response",
+			"status": "in_progress",
+			"output": []any{},
		},
 	})
 }
@@ -995,10 +762,9 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
 
 	// Emit delta
 	events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
 		"item_id":      c.reasoningItemID,
 		"output_index": c.outputIndex,
-		"summary_index": 0,
 		"delta":        thinking,
 	}))
 
 	// TODO(drifkin): consider adding
@@ -1017,10 +783,9 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
 
 	events := []ResponsesStreamEvent{
 		c.newEvent("response.reasoning_summary_text.done", map[string]any{
 			"item_id":      c.reasoningItemID,
 			"output_index": c.outputIndex,
-			"summary_index": 0,
 			"text":         c.accumulatedThinking,
 		}),
 		c.newEvent("response.output_item.done", map[string]any{
 			"output_index": c.outputIndex,
@@ -1133,10 +898,8 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
 			"output_index":  c.outputIndex,
 			"content_index": c.contentIndex,
 			"part": map[string]any{
 				"type": "output_text",
 				"text": "",
-				"annotations": []any{},
-				"logprobs":    []any{},
 			},
 		}))
 	}
@@ -1150,7 +913,6 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
 		"output_index":  c.outputIndex,
 		"content_index": 0,
 		"delta":         content,
-		"logprobs":      []any{},
 	}))
 
 	return events
@@ -1182,10 +944,8 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
 			"status": "completed",
 			"role":   "assistant",
 			"content": []map[string]any{{
 				"type": "output_text",
 				"text": c.accumulatedText,
-				"annotations": []any{},
-				"logprobs":    []any{},
 			}},
 		})
 	}
@@ -1207,7 +967,6 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
 		"output_index":  c.outputIndex,
 		"content_index": 0,
 		"text":          c.accumulatedText,
-		"logprobs":      []any{},
 	}))
 
 	// response.content_part.done
@@ -1216,10 +975,8 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
 		"output_index":  c.outputIndex,
 		"content_index": 0,
 		"part": map[string]any{
 			"type": "output_text",
 			"text": c.accumulatedText,
-			"annotations": []any{},
-			"logprobs":    []any{},
 		},
 	}))
 
@@ -1232,31 +989,26 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
 			"status": "completed",
 			"role":   "assistant",
 			"content": []map[string]any{{
 				"type": "output_text",
 				"text": c.accumulatedText,
-				"annotations": []any{},
-				"logprobs":    []any{},
 			}},
 		},
 	}))
 	}
 
 	// response.completed
-	usage := map[string]any{
-		"input_tokens":  r.PromptEvalCount,
-		"output_tokens": r.EvalCount,
-		"total_tokens":  r.PromptEvalCount + r.EvalCount,
-		"input_tokens_details": map[string]any{
-			"cached_tokens": 0,
-		},
-		"output_tokens_details": map[string]any{
-			"reasoning_tokens": 0,
-		},
-	}
-	response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
-	response["completed_at"] = time.Now().Unix()
 	events = append(events, c.newEvent("response.completed", map[string]any{
-		"response": response,
+		"response": map[string]any{
+			"id":     c.responseID,
+			"object": "response",
+			"status": "completed",
+			"output": c.buildFinalOutput(),
+			"usage": map[string]any{
+				"input_tokens":  r.PromptEvalCount,
+				"output_tokens": r.EvalCount,
+				"total_tokens":  r.PromptEvalCount + r.EvalCount,
+			},
+		},
 	}))
 
 	return events
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
 }
 
 func TestResponsesStreamConverter_TextOnly(t *testing.T) {
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	// First chunk with content
 	events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
 }
 
 func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	events := converter.Process(api.ChatResponse{
 		Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
 }
 
 func TestResponsesStreamConverter_Reasoning(t *testing.T) {
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	// First chunk with thinking
 	events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
 			Content: "The answer is 42",
 		},
 		Done: true,
-	}, ResponsesRequest{})
+	})
 
 	// Should have 2 output items: reasoning + message
 	if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
 
 func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
 	// Verify that response.output_item.done includes content field for messages
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	// First chunk
 	converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
 
 func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
 	// Verify that response.completed includes the output array
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	// Process some content
 	converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
 
 func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
 	// Verify that response.created includes an empty output array
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	events := converter.Process(api.ChatResponse{
 		Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
 
 func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
 	// Verify that events include incrementing sequence numbers
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	events := converter.Process(api.ChatResponse{
 		Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
 
 func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
 	// Verify that function call items include status field
-	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
+	converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
 
 	events := converter.Process(api.ChatResponse{
 		Message: api.Message{
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"strings"
 )
 
 type Prompt struct {
@@ -37,11 +36,10 @@ type Terminal struct {
 }
 
 type Instance struct {
 	Prompt   *Prompt
 	Terminal *Terminal
 	History  *History
 	Pasting  bool
-	pastedLines []string
 }
 
 func New(prompt Prompt) (*Instance, error) {
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
 		case CharEsc:
 			esc = true
 		case CharInterrupt:
-			i.pastedLines = nil
-			i.Prompt.UseAlt = false
 			return "", ErrInterrupt
 		case CharPrev:
 			i.historyPrev(buf, &currentLineBuf)
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
 		case CharForward:
 			buf.MoveRight()
 		case CharBackspace, CharCtrlH:
-			if buf.IsEmpty() && len(i.pastedLines) > 0 {
-				lastIdx := len(i.pastedLines) - 1
-				prevLine := i.pastedLines[lastIdx]
-				i.pastedLines = i.pastedLines[:lastIdx]
-				fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
-				if len(i.pastedLines) == 0 {
-					fmt.Print(i.Prompt.Prompt)
-					i.Prompt.UseAlt = false
-				} else {
-					fmt.Print(i.Prompt.AltPrompt)
-				}
-				for _, r := range prevLine {
-					buf.Add(r)
-				}
-			} else {
-				buf.Remove()
-			}
+			buf.Remove()
 		case CharTab:
 			// todo: convert back to real tabs
 			for range 8 {
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
 		case CharCtrlZ:
 			fd := os.Stdin.Fd()
 			return handleCharCtrlZ(fd, i.Terminal.termios)
-		case CharCtrlJ:
-			i.pastedLines = append(i.pastedLines, buf.String())
-			buf.Buf.Clear()
-			buf.Pos = 0
-			buf.DisplayPos = 0
-			buf.LineHasSpace.Clear()
-			fmt.Println()
-			fmt.Print(i.Prompt.AltPrompt)
-			i.Prompt.UseAlt = true
-			continue
-		case CharEnter:
+		case CharEnter, CharCtrlJ:
 			output := buf.String()
-			if len(i.pastedLines) > 0 {
-				output = strings.Join(i.pastedLines, "\n") + "\n" + output
-				i.pastedLines = nil
-			}
 			if output != "" {
 				i.History.Add(output)
 			}
 			buf.MoveToEnd()
 			fmt.Println()
-			i.Prompt.UseAlt = false
 
 			return output, nil
 		default:
@@ -179,7 +179,7 @@ _build_macapp() {
     fi
 
     rm -f dist/Ollama-darwin.zip
-    ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+    ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
 
    # Notarize and Staple
@@ -187,7 +187,7 @@ _build_macapp() {
    $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
    rm -f dist/Ollama-darwin.zip
    $(xcrun -f stapler) staple dist/Ollama.app
-    ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
+    ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
 
    rm -f dist/Ollama.dmg
 
@@ -50,17 +50,12 @@ func (r registryChallenge) URL() (*url.URL, error) {
 	return redirectURL, nil
 }
 
-func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
+func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
 	redirectURL, err := challenge.URL()
 	if err != nil {
 		return "", err
 	}
 
-	// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
-	if redirectURL.Host != originalHost {
-		return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
-	}
-
 	sha256sum := sha256.Sum256(nil)
 	data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
 
@@ -1,113 +0,0 @@
-package server
-
-import (
-	"context"
-	"strings"
-	"testing"
-	"time"
-)
-
-func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
-	tests := []struct {
-		realm        string
-		originalHost string
-		wantMismatch bool
-	}{
-		{"https://example.com/token", "example.com", false},
-		{"https://example.com/token", "other.com", true},
-		{"https://example.com/token", "localhost:8000", true},
-		{"https://localhost:5000/token", "localhost:5000", false},
-		{"https://localhost:5000/token", "localhost:6000", true},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.originalHost, func(t *testing.T) {
-			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
-			defer cancel()
-
-			challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
-			_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
-
-			isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
-			if tt.wantMismatch && !isMismatch {
-				t.Errorf("expected domain mismatch error, got: %v", err)
-			}
-			if !tt.wantMismatch && isMismatch {
-				t.Errorf("unexpected domain mismatch error: %v", err)
-			}
-		})
-	}
-}
-
-func TestParseRegistryChallenge(t *testing.T) {
-	tests := []struct {
-		input                             string
-		wantRealm, wantService, wantScope string
-	}{
-		{
-			`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
-			"https://auth.example.com/token", "registry", "repo:foo:pull",
-		},
-		{
-			`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
-			"https://r.ollama.ai/v2/token", "ollama", "-",
-		},
-		{"", "", "", ""},
-	}
-
-	for _, tt := range tests {
-		result := parseRegistryChallenge(tt.input)
-		if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
-			t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
-				tt.input, result.Realm, result.Service, result.Scope,
-				tt.wantRealm, tt.wantService, tt.wantScope)
-		}
-	}
-}
-
-func TestRegistryChallengeURL(t *testing.T) {
-	challenge := registryChallenge{
-		Realm:   "https://auth.example.com/token",
-		Service: "registry",
-		Scope:   "repo:foo:pull repo:bar:push",
-	}
-
-	u, err := challenge.URL()
-	if err != nil {
-		t.Fatalf("URL() error: %v", err)
-	}
-
-	if u.Host != "auth.example.com" {
-		t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
-	}
-	if u.Path != "/token" {
-		t.Errorf("path = %q, want %q", u.Path, "/token")
-	}
-
-	q := u.Query()
-	if q.Get("service") != "registry" {
-		t.Errorf("service = %q, want %q", q.Get("service"), "registry")
-	}
-	if scopes := q["scope"]; len(scopes) != 2 {
-		t.Errorf("scope count = %d, want 2", len(scopes))
-	}
-	if q.Get("ts") == "" {
-		t.Error("missing ts")
-	}
-	if q.Get("nonce") == "" {
-		t.Error("missing nonce")
-	}
-
-	// Nonces should differ between calls
-	u2, _ := challenge.URL()
-	if q.Get("nonce") == u2.Query().Get("nonce") {
-		t.Error("nonce should be unique per call")
-	}
-}
-
-func TestRegistryChallengeURLInvalid(t *testing.T) {
-	challenge := registryChallenge{Realm: "://invalid"}
-	if _, err := challenge.URL(); err == nil {
-		t.Error("expected error for invalid URL")
-	}
-}
@@ -95,11 +95,48 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 }
 
 const (
-	numDownloadParts          = 16
+	// numDownloadParts is the default number of concurrent download parts for standard downloads
+	numDownloadParts = 16
+	// numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
+	// downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
+	numHFDownloadParts = 4
 	minDownloadPartSize int64 = 100 * format.MegaByte
 	maxDownloadPartSize int64 = 1000 * format.MegaByte
 )
 
+// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
+// This includes:
+//   - huggingface.co (main domain)
+//   - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
+//   - hf.co (shortlink domain)
+//   - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
+func isHuggingFaceURL(u *url.URL) bool {
+	if u == nil {
+		return false
+	}
+	host := strings.ToLower(u.Hostname())
+	return host == "huggingface.co" ||
+		strings.HasSuffix(host, ".huggingface.co") ||
+		host == "hf.co" ||
+		strings.HasSuffix(host, ".hf.co")
+}
+
+// getNumDownloadParts returns the number of concurrent download parts to use
+// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
+// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
+// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
+func getNumDownloadParts(u *url.URL) int {
+	if isHuggingFaceURL(u) {
+		if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
+			if n, err := strconv.Atoi(v); err == nil && n > 0 {
+				return n
+			}
+		}
+		return numHFDownloadParts
+	}
+	return numDownloadParts
+}
+
 func (p *blobDownloadPart) Name() string {
 	return strings.Join([]string{
 		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -271,7 +308,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	}
 
 	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numDownloadParts)
+	concurrency := getNumDownloadParts(directURL)
+	if concurrency != numDownloadParts {
+		slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
+	}
+	g.SetLimit(concurrency)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed.Load() == part.Size {
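Taken together, the two download.go hunks above amount to a simple policy: HuggingFace-hosted blobs default to 4 concurrent parts instead of 16, with `OLLAMA_HF_CONCURRENCY` as an escape hatch. A minimal standalone sketch of that selection logic (mirroring the new helpers rather than importing them):

```go
package main

import (
	"fmt"
	"net/url"
	"os"
	"strconv"
	"strings"
)

// isHF mirrors isHuggingFaceURL above: match the apex domains and any
// subdomain, but not lookalikes such as nothf.co.
func isHF(u *url.URL) bool {
	host := strings.ToLower(u.Hostname())
	return host == "huggingface.co" || strings.HasSuffix(host, ".huggingface.co") ||
		host == "hf.co" || strings.HasSuffix(host, ".hf.co")
}

// parts mirrors getNumDownloadParts above.
func parts(u *url.URL) int {
	if isHF(u) {
		if n, err := strconv.Atoi(os.Getenv("OLLAMA_HF_CONCURRENCY")); err == nil && n > 0 {
			return n // explicit override wins
		}
		return 4 // reduced default to stay under HF rate limits
	}
	return 16 // standard concurrency
}

func main() {
	for _, raw := range []string{
		"https://cdn-lfs3.hf.co/repos/abc/123",
		"https://registry.ollama.ai/v2/library/llama3",
	} {
		u, _ := url.Parse(raw)
		fmt.Println(raw, "->", parts(u))
	}
}
```

Run with `OLLAMA_HF_CONCURRENCY=8` in the environment, the hf.co URL would report 8 while the registry URL stays at 16, which matches the test cases in the new file below.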
server/download_test.go (new file, 194 lines)
@@ -0,0 +1,194 @@
+package server
+
+import (
+	"net/url"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestIsHuggingFaceURL(t *testing.T) {
+	tests := []struct {
+		name     string
+		url      string
+		expected bool
+	}{
+		{
+			name:     "nil url",
+			url:      "",
+			expected: false,
+		},
+		{
+			name:     "huggingface.co main domain",
+			url:      "https://huggingface.co/some/model",
+			expected: true,
+		},
+		{
+			name:     "cdn-lfs.huggingface.co subdomain",
+			url:      "https://cdn-lfs.huggingface.co/repos/abc/123",
+			expected: true,
+		},
+		{
+			name:     "cdn-lfs3.hf.co CDN domain",
+			url:      "https://cdn-lfs3.hf.co/repos/abc/123",
+			expected: true,
+		},
+		{
+			name:     "hf.co shortlink domain",
+			url:      "https://hf.co/model",
+			expected: true,
+		},
+		{
+			name:     "uppercase HuggingFace domain",
+			url:      "https://HUGGINGFACE.CO/model",
+			expected: true,
+		},
+		{
+			name:     "mixed case HF domain",
+			url:      "https://Cdn-Lfs.HF.Co/repos",
+			expected: true,
+		},
+		{
+			name:     "ollama registry",
+			url:      "https://registry.ollama.ai/v2/library/llama3",
+			expected: false,
+		},
+		{
+			name:     "github.com",
+			url:      "https://github.com/ollama/ollama",
+			expected: false,
+		},
+		{
+			name:     "fake huggingface domain",
+			url:      "https://nothuggingface.co/model",
+			expected: false,
+		},
+		{
+			name:     "fake hf domain",
+			url:      "https://nothf.co/model",
+			expected: false,
+		},
+		{
+			name:     "huggingface in path not host",
+			url:      "https://example.com/huggingface.co/model",
+			expected: false,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			var u *url.URL
+			if tc.url != "" {
+				var err error
+				u, err = url.Parse(tc.url)
+				if err != nil {
+					t.Fatalf("failed to parse URL: %v", err)
+				}
+			}
+			got := isHuggingFaceURL(u)
+			assert.Equal(t, tc.expected, got)
+		})
+	}
+}
+
+func TestGetNumDownloadParts(t *testing.T) {
+	tests := []struct {
+		name        string
+		url         string
+		envValue    string
+		expected    int
+		description string
+	}{
+		{
+			name:        "nil url returns default",
+			url:         "",
+			envValue:    "",
+			expected:    numDownloadParts,
+			description: "nil URL should return standard concurrency",
+		},
+		{
+			name:        "ollama registry returns default",
+			url:         "https://registry.ollama.ai/v2/library/llama3",
+			envValue:    "",
+			expected:    numDownloadParts,
+			description: "Ollama registry should use standard concurrency",
+		},
+		{
+			name:        "huggingface returns reduced default",
+			url:         "https://huggingface.co/model/repo",
+			envValue:    "",
+			expected:    numHFDownloadParts,
+			description: "HuggingFace should use reduced concurrency",
+		},
+		{
+			name:        "hf.co CDN returns reduced default",
+			url:         "https://cdn-lfs3.hf.co/repos/abc/123",
+			envValue:    "",
+			expected:    numHFDownloadParts,
+			description: "HuggingFace CDN should use reduced concurrency",
+		},
+		{
+			name:        "huggingface with env override",
+			url:         "https://huggingface.co/model/repo",
+			envValue:    "2",
+			expected:    2,
+			description: "OLLAMA_HF_CONCURRENCY should override default",
+		},
+		{
+			name:        "huggingface with higher env override",
+			url:         "https://huggingface.co/model/repo",
+			envValue:    "8",
+			expected:    8,
+			description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
+		},
+		{
+			name:        "huggingface with invalid env (non-numeric)",
+			url:         "https://huggingface.co/model/repo",
+			envValue:    "invalid",
+			expected:    numHFDownloadParts,
+			description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
+		},
+		{
+			name:        "huggingface with invalid env (zero)",
+			url:         "https://huggingface.co/model/repo",
+			envValue:    "0",
+			expected:    numHFDownloadParts,
+			description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
+		},
+		{
+			name:        "huggingface with invalid env (negative)",
+			url:         "https://huggingface.co/model/repo",
+			envValue:    "-1",
+			expected:    numHFDownloadParts,
+			description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
+		},
+		{
+			name:        "non-huggingface ignores env",
+			url:         "https://registry.ollama.ai/v2/library/llama3",
+			envValue:    "2",
+			expected:    numDownloadParts,
+			description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			// Set or clear the environment variable
+			if tc.envValue != "" {
+				t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
+			}
+
+			var u *url.URL
+			if tc.url != "" {
+				var err error
+				u, err = url.Parse(tc.url)
+				if err != nil {
+					t.Fatalf("failed to parse URL: %v", err)
+				}
+			}
+
+			got := getNumDownloadParts(u)
+			assert.Equal(t, tc.expected, got, tc.description)
+		})
+	}
+}
@@ -775,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
 			Realm:   challenge.Realm,
 			Service: challenge.Service,
 			Scope:   challenge.Scope,
-		}, base.Host)
+		})
 	}
 
 	if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
 			Realm:   challenge.Realm,
 			Service: challenge.Service,
 			Scope:   challenge.Scope,
-		}, base.Host)
+		})
 	}
 
 	return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 
 		// Handle authentication error with one retry
 		challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
-		token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
+		token, err := getAuthorizationToken(ctx, challenge)
 		if err != nil {
 			return nil, err
 		}
@@ -51,6 +51,7 @@ import (
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/x/imagegen"
+	imagegenapi "github.com/ollama/ollama/x/imagegen/api"
 )
 
 const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -163,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
 	return runner.llama, model, &opts, nil
 }
 
+// ScheduleImageGenRunner schedules an image generation model runner.
+// This implements the imagegenapi.RunnerScheduler interface.
+func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
+	m := &Model{
+		Name:      modelName,
+		ShortName: modelName,
+		ModelPath: modelName, // For image gen, ModelPath is just the model name
+		Config: model.ConfigV2{
+			Capabilities: []string{"image"},
+		},
+	}
+
+	runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
+	var runner *runnerRef
+	select {
+	case runner = <-runnerCh:
+	case err := <-errCh:
+		return nil, err
+	}
+
+	return runner.llama, nil
+}
+
 func signinURL() (string, error) {
 	pubKey, err := auth.GetPublicKey()
 	if err != nil {
@@ -190,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	// Check if this is a known image generation model
+	if imagegen.ResolveModelName(req.Model) != "" {
+		imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
+		return
+	}
+
 	name := model.ParseName(req.Model)
 	if !name.IsValid() {
 		// Ideally this is "invalid model name" but we're keeping with
@@ -1557,12 +1587,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
 	r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
-	// Experimental OpenAI-compatible image generation endpoint
-	r.POST("/v1/images/generations", s.handleImageGeneration)
 
 	// Inference (Anthropic compatibility)
 	r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
 
+	// Experimental image generation support
+	imagegenapi.RegisterRoutes(r, s)
+
 	if rc != nil {
 		// wrap old with new
 		rs := &registry.Local{
@@ -1880,62 +1911,6 @@ func toolCallId() string {
 	return "call_" + strings.ToLower(string(b))
 }
 
-func (s *Server) handleImageGeneration(c *gin.Context) {
-	var req struct {
-		Model  string `json:"model"`
-		Prompt string `json:"prompt"`
-		Size   string `json:"size"`
-	}
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	m, err := GetModel(req.Model)
-	if err != nil {
-		c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
-		return
-	}
-
-	runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
-	var runner *runnerRef
-	select {
-	case runner = <-runnerCh:
-	case err := <-errCh:
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	// Parse size (e.g., "1024x768") into width and height
-	width, height := int32(1024), int32(1024)
-	if req.Size != "" {
-		if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
-			return
-		}
-	}
-
-	var image []byte
-	err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-		Prompt: req.Prompt,
-		Width:  width,
-		Height: height,
-	}, func(resp llm.CompletionResponse) {
-		if len(resp.Image) > 0 {
-			image = resp.Image
-		}
-	})
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	c.JSON(http.StatusOK, gin.H{
-		"created": time.Now().Unix(),
-		"data":    []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
-	})
-}
-
 func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
@@ -6,6 +6,7 @@ import (
 	"errors"
 	"log/slog"
 	"os"
+	"slices"
 	"testing"
 	"time"
 
@@ -16,6 +17,7 @@ import (
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/types/model"
 )
 
 func TestMain(m *testing.M) {
@@ -805,8 +807,32 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
 func (s *mockLlm) HasExited() bool                   { return false }
 func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
 
+// TestImageGenCapabilityDetection verifies that models with "image" capability
+// are correctly identified and routed differently from language models.
+func TestImageGenCapabilityDetection(t *testing.T) {
+	// Model with image capability should be detected
+	imageModel := &Model{
+		Config: model.ConfigV2{
+			Capabilities: []string{"image"},
+		},
+	}
+	require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
+
+	// Model without image capability should not be detected
+	langModel := &Model{
+		Config: model.ConfigV2{
+			Capabilities: []string{"completion"},
+		},
+	}
+	require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
+
+	// Empty capabilities should not match
+	emptyModel := &Model{}
+	require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
+}
+
 // TestImageGenRunnerCanBeEvicted verifies that an image generation model
-// loaded in the scheduler can be evicted when idle.
+// loaded in the scheduler can be evicted by a language model request.
 func TestImageGenRunnerCanBeEvicted(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
 	defer done()
@@ -838,59 +864,3 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
 	require.NotNil(t, runner)
 	require.Equal(t, "/fake/image/model", runner.modelPath)
 }
-
-// TestImageGenSchedulerCoexistence verifies that image generation models
-// can coexist with language models in the scheduler and VRAM is tracked correctly.
-func TestImageGenSchedulerCoexistence(t *testing.T) {
-	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
-	defer done()
-
-	s := InitScheduler(ctx)
-	s.getGpuFn = getGpuFn
-	s.getSystemInfoFn = getSystemInfoFn
-
-	// Load both an imagegen runner and a language model runner
-	imageGenRunner := &runnerRef{
-		model:           &Model{Name: "flux", ModelPath: "/fake/flux/model"},
-		modelPath:       "/fake/flux/model",
-		llama:           &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
-		sessionDuration: 10 * time.Millisecond,
-		numParallel:     1,
-		refCount:        0,
-	}
-
-	langModelRunner := &runnerRef{
-		model:           &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
-		modelPath:       "/fake/llama3/model",
-		llama:           &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
-		sessionDuration: 10 * time.Millisecond,
-		numParallel:     1,
-		refCount:        0,
-	}
-
-	s.loadedMu.Lock()
-	s.loaded["/fake/flux/model"] = imageGenRunner
-	s.loaded["/fake/llama3/model"] = langModelRunner
-	s.loadedMu.Unlock()
-
-	// Verify both are loaded
-	s.loadedMu.Lock()
-	require.Len(t, s.loaded, 2)
-	require.NotNil(t, s.loaded["/fake/flux/model"])
-	require.NotNil(t, s.loaded["/fake/llama3/model"])
-	s.loadedMu.Unlock()
-
-	// Verify updateFreeSpace accounts for both
-	gpus := []ml.DeviceInfo{
-		{
-			DeviceID:    ml.DeviceID{Library: "Metal"},
-			TotalMemory: 24 * format.GigaByte,
-			FreeMemory:  24 * format.GigaByte,
-		},
-	}
-	s.updateFreeSpace(gpus)
-
-	// Free memory should be reduced by both models
-	expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
-	require.Equal(t, expectedFree, gpus[0].FreeMemory)
-}
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 	case resp.StatusCode == http.StatusUnauthorized:
 		w.Rollback()
 		challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
-		token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
+		token, err := getAuthorizationToken(ctx, challenge)
 		if err != nil {
 			return err
 		}
x/README.md (new file, 50 lines)
@@ -0,0 +1,50 @@
+# Experimental Features
+
+## MLX Backend
+
+We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx).
+
+Support is currently limited to macOS and to Linux with CUDA GPUs. We're looking to add support for Windows CUDA, and for other GPU vendors, soon.
+
+### Building ollama-mlx
+
+The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. It enables experimental features such as image generation.
+
+#### macOS (Apple Silicon and Intel)
+
+```bash
+# Build MLX backend libraries
+cmake --preset MLX
+cmake --build --preset MLX --parallel
+cmake --install build --component MLX
+
+# Build ollama-mlx binary
+go build -tags mlx -o ollama-mlx .
+```
+
+#### Linux (CUDA)
+
+On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
+
+```bash
+# Build MLX backend libraries with CUDA support
+cmake --preset 'MLX CUDA 13'
+cmake --build --preset 'MLX CUDA 13' --parallel
+cmake --install build --component MLX
+
+# Build ollama-mlx binary
+CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
+CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
+go build -tags mlx -o ollama-mlx .
+```
+
+#### Using build scripts
+
+The build scripts automatically create the `ollama-mlx` binary:
+
+- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
+- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
+
+## Image Generation
+
+Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
x/cmd/run.go (67 changed lines)
@@ -25,6 +25,14 @@ import (
 	"github.com/ollama/ollama/x/tools"
 )
 
+// MultilineState tracks the state of multiline input
+type MultilineState int
+
+const (
+	MultilineNone MultilineState = iota
+	MultilineSystem
+)
+
 // Tool output capping constants
 const (
 	// localModelTokenLimit is the token limit for local models (smaller context).
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 		Prompt:         ">>> ",
 		AltPrompt:      "... ",
 		Placeholder:    "Send a message (/? for help)",
-		AltPlaceholder: "Press Enter to send",
+		AltPlaceholder: `Use """ to end multi-line input`,
 	})
 	if err != nil {
 		return err
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 	var sb strings.Builder
 	var format string
 	var system string
+	var multiline MultilineState = MultilineNone
 
 	for {
 		line, err := scanner.Readline()
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 			}
 			scanner.Prompt.UseAlt = false
 			sb.Reset()
+			multiline = MultilineNone
 			continue
 		case err != nil:
 			return err
 		}
 
 		switch {
+		case multiline != MultilineNone:
+			// check if there's a multiline terminating string
+			before, ok := strings.CutSuffix(line, `"""`)
+			sb.WriteString(before)
+			if !ok {
+				fmt.Fprintln(&sb)
+				continue
+			}
+
+			switch multiline {
+			case MultilineSystem:
+				system = sb.String()
+				newMessage := api.Message{Role: "system", Content: system}
+				if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
+					messages[len(messages)-1] = newMessage
+				} else {
+					messages = append(messages, newMessage)
+				}
+				fmt.Println("Set system message.")
+				sb.Reset()
+			}
+
+			multiline = MultilineNone
+			scanner.Prompt.UseAlt = false
 		case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
 			return nil
 		case strings.HasPrefix(line, "/clear"):
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 				options[args[2]] = fp[args[2]]
 			case "system":
 				if len(args) < 3 {
-					fmt.Println("Usage: /set system <message>")
+					fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
 					continue
 				}
 
-				system = strings.Join(args[2:], " ")
-				newMessage := api.Message{Role: "system", Content: system}
+				multiline = MultilineSystem
+
+				line := strings.Join(args[2:], " ")
+				line, ok := strings.CutPrefix(line, `"""`)
+				if !ok {
+					multiline = MultilineNone
+				} else {
+					// only cut suffix if the line is multiline
+					line, ok = strings.CutSuffix(line, `"""`)
+					if ok {
+						multiline = MultilineNone
+					}
+				}
+
+				sb.WriteString(line)
+				if multiline != MultilineNone {
+					scanner.Prompt.UseAlt = true
+					continue
+				}
+
+				system = sb.String()
+				newMessage := api.Message{Role: "system", Content: sb.String()}
+				// Check if the slice is not empty and the last message is from 'system'
 				if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
+					// Replace the last message
 					messages[len(messages)-1] = newMessage
 				} else {
 					messages = append(messages, newMessage)
 				}
 				fmt.Println("Set system message.")
+				sb.Reset()
 				continue
 			default:
 				fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 			sb.WriteString(line)
 		}
 
-		if sb.Len() > 0 {
+		if sb.Len() > 0 && multiline == MultilineNone {
 			newMessage := api.Message{Role: "user", Content: sb.String()}
 			messages = append(messages, newMessage)
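The new `/set system` path above drives its state machine with `strings.CutPrefix`/`strings.CutSuffix` (standard library, Go 1.20+), which return both the trimmed string and whether the delimiter was actually present. A quick sketch of the three delimiter cases the hunk distinguishes:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Opening delimiter: CutPrefix reports whether the line starts a block.
	rest, opened := strings.CutPrefix(`"""You are a`, `"""`)
	fmt.Printf("opened=%v rest=%q\n", opened, rest) // opened=true rest="You are a"

	// Closing delimiter: CutSuffix reports whether the line ends the block.
	before, closed := strings.CutSuffix(`poet."""`, `"""`)
	fmt.Printf("closed=%v before=%q\n", closed, before) // closed=true before="poet."

	// A single line can open and close in one go: /set system """one-liner"""
	rest, _ = strings.CutPrefix(`"""one-liner"""`, `"""`)
	body, closed := strings.CutSuffix(rest, `"""`)
	fmt.Printf("closed=%v body=%q\n", closed, body) // closed=true body="one-liner"
}
```

When the suffix is absent, `multiline` stays set and the scanner switches to the `... ` alt prompt until a terminating `"""` arrives, exactly as the hunk shows.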
x/imagegen/api/handler.go (new file, 231 lines)
@@ -0,0 +1,231 @@
+package api
+
+import (
+	"fmt"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/x/imagegen"
+)
+
+// RunnerScheduler is the interface for scheduling a model runner.
+// This is implemented by server.Server to avoid circular imports.
+type RunnerScheduler interface {
+	ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
+}
+
+// RegisterRoutes registers the image generation API routes.
+func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
+	r.POST("/v1/images/generations", func(c *gin.Context) {
+		ImageGenerationHandler(c, scheduler)
+	})
+}
+
+// ImageGenerationHandler handles OpenAI-compatible image generation requests.
+func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
+	var req ImageGenerationRequest
+	if err := c.BindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
+		return
+	}
+
+	// Validate required fields
+	if req.Model == "" {
+		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
+		return
+	}
+	if req.Prompt == "" {
+		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
+		return
+	}
+
+	// Apply defaults
+	if req.N == 0 {
+		req.N = 1
+	}
+	if req.Size == "" {
+		req.Size = "1024x1024"
+	}
+	if req.ResponseFormat == "" {
+		req.ResponseFormat = "b64_json"
+	}
+
+	// Verify model exists
+	if imagegen.ResolveModelName(req.Model) == "" {
+		c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
+		return
+	}
+
+	// Parse size
+	width, height := parseSize(req.Size)
+
+	// Build options - we repurpose NumCtx/NumGPU for width/height
+	opts := api.Options{}
+	opts.NumCtx = int(width)
+	opts.NumGPU = int(height)
+
+	// Schedule runner
+	runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
+	if err != nil {
+		status := http.StatusInternalServerError
+		if strings.Contains(err.Error(), "not found") {
+			status = http.StatusNotFound
+		}
+		c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
+		return
+	}
+
+	// Build completion request
+	completionReq := llm.CompletionRequest{
+		Prompt:  req.Prompt,
+		Options: &opts,
+	}
+
+	if req.Stream {
+		handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
+	} else {
+		handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
+	}
+}
+
+func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
+	c.Header("Content-Type", "text/event-stream")
+	c.Header("Cache-Control", "no-cache")
+	c.Header("Connection", "keep-alive")
+
+	var imageBase64 string
+	err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
+		if resp.Done {
+			imageBase64 = extractBase64(resp.Content)
+		} else {
+			progress := parseProgress(resp.Content)
+			if progress.Total > 0 {
+				c.SSEvent("progress", progress)
+				c.Writer.Flush()
+			}
+		}
+	})
+	if err != nil {
+		c.SSEvent("error", gin.H{"error": err.Error()})
+		return
+	}
+
+	c.SSEvent("done", buildResponse(imageBase64, format))
+}
+
+func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
+	var imageBase64 string
+	err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
+		if resp.Done {
+			imageBase64 = extractBase64(resp.Content)
+		}
+	})
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+		return
+	}
+
+	c.JSON(http.StatusOK, buildResponse(imageBase64, format))
+}
+
+func parseSize(size string) (int32, int32) {
+	parts := strings.Split(size, "x")
+	if len(parts) != 2 {
+		return 1024, 1024
+	}
+	w, _ := strconv.Atoi(parts[0])
+	h, _ := strconv.Atoi(parts[1])
+	if w == 0 {
+		w = 1024
+	}
+	if h == 0 {
+		h = 1024
+	}
+	return int32(w), int32(h)
+}
+
+func extractBase64(content string) string {
+	if strings.HasPrefix(content, "IMAGE_BASE64:") {
+		return content[13:]
+	}
+	return ""
+}
+
+func parseProgress(content string) ImageProgressEvent {
+	var step, total int
+	fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
+	return ImageProgressEvent{Step: step, Total: total}
+}
+
+func buildResponse(imageBase64, format string) ImageGenerationResponse {
+	resp := ImageGenerationResponse{
+		Created: time.Now().Unix(),
+		Data:    make([]ImageData, 1),
+	}
+
+	if imageBase64 == "" {
+		return resp
+	}
+
+	if format == "url" {
+		// URL format not supported when using base64 transfer
+		resp.Data[0].B64JSON = imageBase64
+	} else {
+		resp.Data[0].B64JSON = imageBase64
+	}
+
+	return resp
+}
+
+// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
+// This allows routes.go to delegate image generation with minimal code.
+func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
+	opts := api.Options{}
+
+	// Schedule runner
+	runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	// Build completion request
+	completionReq := llm.CompletionRequest{
+		Prompt:  prompt,
+		Options: &opts,
+	}
+
+	// Stream responses via channel
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
+			ch <- GenerateResponse{
+				Model:     modelName,
+				CreatedAt: time.Now().UTC(),
+				Response:  resp.Content,
+				Done:      resp.Done,
+			}
+		})
+		if err != nil {
+			// Log error but don't block - channel is already being consumed
+			_ = err
+		}
+	}()
+
+	streamFn(c, ch)
+}
+
+// GenerateResponse matches api.GenerateResponse structure for streaming.
+type GenerateResponse struct {
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Response  string    `json:"response"`
+	Done      bool      `json:"done"`
+}
x/imagegen/api/types.go (new file, 31 lines)
@@ -0,0 +1,31 @@
+// Package api provides OpenAI-compatible image generation API types.
+package api
+
+// ImageGenerationRequest is an OpenAI-compatible image generation request.
+type ImageGenerationRequest struct {
+	Model          string `json:"model"`
+	Prompt         string `json:"prompt"`
+	N              int    `json:"n,omitempty"`
+	Size           string `json:"size,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+	Stream         bool   `json:"stream,omitempty"`
+}
+
+// ImageGenerationResponse is an OpenAI-compatible image generation response.
+type ImageGenerationResponse struct {
+	Created int64       `json:"created"`
+	Data    []ImageData `json:"data"`
+}
+
+// ImageData contains the generated image data.
+type ImageData struct {
+	URL           string `json:"url,omitempty"`
+	B64JSON       string `json:"b64_json,omitempty"`
+	RevisedPrompt string `json:"revised_prompt,omitempty"`
+}
+
+// ImageProgressEvent is sent during streaming to indicate generation progress.
+type ImageProgressEvent struct {
+	Step  int `json:"step"`
+	Total int `json:"total"`
+}
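With the handler and types above, the experimental endpoint can be exercised with a plain HTTP client. A minimal sketch, assuming the server listens on Ollama's default port 11434 and using a hypothetical model name `z-image`:

```go
package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

func main() {
	// Field names mirror ImageGenerationRequest above; the model name is hypothetical.
	body, _ := json.Marshal(map[string]any{
		"model":  "z-image",
		"prompt": "a lighthouse at dusk",
		"size":   "1024x768",
	})

	resp, err := http.Post("http://localhost:11434/v1/images/generations",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The handler defaults response_format to b64_json, so the image arrives inline.
	var out struct {
		Data []struct {
			B64JSON string `json:"b64_json"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if len(out.Data) == 0 || out.Data[0].B64JSON == "" {
		panic("no image in response")
	}

	img, err := base64.StdEncoding.DecodeString(out.Data[0].B64JSON)
	if err != nil {
		panic(err)
	}
	if err := os.WriteFile("out.png", img, 0o644); err != nil {
		panic(err)
	}
	fmt.Println("wrote out.png")
}
```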
@@ -7,6 +7,7 @@ package imagegen
 
 import (
 	"encoding/base64"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -38,17 +39,77 @@ func DefaultOptions() ImageGenOptions {
 	return ImageGenOptions{
 		Width:  1024,
 		Height: 1024,
-		Steps:  0, // 0 means model default
+		Steps:  9,
 		Seed:   0, // 0 means random
 	}
 }
 
+// ModelInfo contains metadata about an image generation model.
+type ModelInfo struct {
+	Architecture   string
+	ParameterCount int64
+	Quantization   string
+}
+
+// GetModelInfo returns metadata about an image generation model.
+func GetModelInfo(modelName string) (*ModelInfo, error) {
+	manifest, err := LoadManifest(modelName)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load manifest: %w", err)
+	}
+
+	info := &ModelInfo{}
+
+	// Read model_index.json for architecture, parameter count, and quantization
+	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
+		var index struct {
+			Architecture   string `json:"architecture"`
+			ParameterCount int64  `json:"parameter_count"`
+			Quantization   string `json:"quantization"`
+		}
+		if json.Unmarshal(data, &index) == nil {
+			info.Architecture = index.Architecture
+			info.ParameterCount = index.ParameterCount
+			info.Quantization = index.Quantization
+		}
+	}
+
+	// Fallback: detect quantization from tensor names if not in config
+	if info.Quantization == "" {
+		for _, layer := range manifest.Manifest.Layers {
+			if strings.HasSuffix(layer.Name, ".weight_scale") {
+				info.Quantization = "FP8"
+				break
+			}
+		}
+		if info.Quantization == "" {
+			info.Quantization = "BF16"
+		}
+	}
+
+	// Fallback: estimate parameter count if not in config
+	if info.ParameterCount == 0 {
+		var totalSize int64
+		for _, layer := range manifest.Manifest.Layers {
+			if layer.MediaType == "application/vnd.ollama.image.tensor" {
+				if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
+					totalSize += layer.Size
+				}
+			}
+		}
+		// Assume BF16 (2 bytes/param) as rough estimate
+		info.ParameterCount = totalSize / 2
+	}
+
+	return info, nil
+}
+
 // RegisterFlags adds image generation flags to the given command.
 // Flags are hidden since they only apply to image generation models.
 func RegisterFlags(cmd *cobra.Command) {
 	cmd.Flags().Int("width", 1024, "Image width")
 	cmd.Flags().Int("height", 1024, "Image height")
-	cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
+	cmd.Flags().Int("steps", 9, "Denoising steps")
 	cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
 	cmd.Flags().String("negative", "", "Negative prompt")
 	cmd.Flags().MarkHidden("width")
@@ -91,18 +152,23 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
 }
 
 // generateImageWithOptions generates an image with the given options.
-// Note: opts are currently unused as the native API doesn't support size parameters.
-// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
-func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
+func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
 
+	// Build request with image gen options encoded in Options fields
+	// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
 	req := &api.GenerateRequest{
 		Model:  modelName,
 		Prompt: prompt,
-		// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
+		Options: map[string]any{
+			"num_ctx":     opts.Width,
+			"num_gpu":     opts.Height,
+			"num_predict": opts.Steps,
+			"seed":        opts.Seed,
+		},
 	}
 	if keepAlive != nil {
 		req.KeepAlive = keepAlive
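The hunk above threads width/height/steps/seed through existing `api.Options` fields rather than extending the public API. A sketch of the equivalent request from any client of the `github.com/ollama/ollama/api` package, with a hypothetical model name:

```go
package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}

	req := &api.GenerateRequest{
		Model:  "z-image", // hypothetical image generation model
		Prompt: "a watercolor fox",
		// Repurposed fields, matching the encoding in the hunk above:
		// num_ctx=width, num_gpu=height, num_predict=steps, seed=seed
		Options: map[string]any{
			"num_ctx":     1024,
			"num_gpu":     1024,
			"num_predict": 9,
			"seed":        42,
		},
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response) // progress text; the final chunk carries the image payload
		return nil
	})
	if err != nil {
		panic(err)
	}
}
```

Reusing `num_ctx`/`num_gpu` this way is a pragmatic shortcut; the OpenAI-compatible `/v1/images/generations` endpoint shown earlier exposes the same knobs as `size` instead.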
@@ -12,7 +12,6 @@ import (
 	"path/filepath"
 	"runtime/pprof"
 
-	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/models/gemma3"
 	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -49,7 +48,7 @@ func main() {
 	// Image generation params
 	width := flag.Int("width", 1024, "Image width")
 	height := flag.Int("height", 1024, "Image height")
-	steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
+	steps := flag.Int("steps", 9, "Denoising steps")
 	seed := flag.Int64("seed", 42, "Random seed")
 	out := flag.String("output", "output.png", "Output path")
@@ -175,63 +175,3 @@ func (m *ModelManifest) HasTensorLayers() bool {
 	}
 	return false
 }
-
-// ModelInfo contains metadata about an image generation model.
-type ModelInfo struct {
-	Architecture   string
-	ParameterCount int64
-	Quantization   string
-}
-
-// GetModelInfo returns metadata about an image generation model.
-func GetModelInfo(modelName string) (*ModelInfo, error) {
-	manifest, err := LoadManifest(modelName)
-	if err != nil {
-		return nil, fmt.Errorf("failed to load manifest: %w", err)
-	}
-
-	info := &ModelInfo{}
-
-	// Read model_index.json for architecture, parameter count, and quantization
-	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
-		var index struct {
-			Architecture   string `json:"architecture"`
-			ParameterCount int64  `json:"parameter_count"`
-			Quantization   string `json:"quantization"`
-		}
-		if json.Unmarshal(data, &index) == nil {
-			info.Architecture = index.Architecture
-			info.ParameterCount = index.ParameterCount
-			info.Quantization = index.Quantization
-		}
-	}
-
-	// Fallback: detect quantization from tensor names if not in config
-	if info.Quantization == "" {
-		for _, layer := range manifest.Manifest.Layers {
-			if strings.HasSuffix(layer.Name, ".weight_scale") {
-				info.Quantization = "FP8"
-				break
-			}
-		}
-		if info.Quantization == "" {
-			info.Quantization = "BF16"
-		}
-	}
-
-	// Fallback: estimate parameter count if not in config
-	if info.ParameterCount == 0 {
-		var totalSize int64
-		for _, layer := range manifest.Manifest.Layers {
-			if layer.MediaType == "application/vnd.ollama.image.tensor" {
-				if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
-					totalSize += layer.Size
-				}
-			}
-		}
-		// Assume BF16 (2 bytes/param) as rough estimate
-		info.ParameterCount = totalSize / 2
-	}
-
-	return info, nil
-}
@@ -95,3 +95,8 @@ func EstimateVRAM(modelName string) uint64 {
 	}
 	return 21 * GB
 }
+
+// HasTensorLayers checks if the given model has tensor layers.
+func HasTensorLayers(modelName string) bool {
+	return ResolveModelName(modelName) != ""
+}
@@ -94,6 +94,13 @@ func TestEstimateVRAMDefault(t *testing.T) {
 	}
 }
 
+func TestHasTensorLayers(t *testing.T) {
+	// Non-existent model should return false
+	if HasTensorLayers("nonexistent-model") {
+		t.Error("HasTensorLayers() should return false for non-existent model")
+	}
+}
+
 func TestResolveModelName(t *testing.T) {
 	// Non-existent model should return empty string
 	result := ResolveModelName("nonexistent-model")
@@ -9,7 +9,6 @@ import (
 	"path/filepath"
 	"time"
 
-	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/cache"
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -173,7 +172,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
 		cfg.Height = 1024
 	}
 	if cfg.Steps <= 0 {
-		cfg.Steps = 50
+		cfg.Steps = 30
 	}
 	if cfg.CFGScale <= 0 {
 		cfg.CFGScale = 4.0
@@ -194,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
 		cfg.Height = 1024
 	}
 	if cfg.Steps <= 0 {
-		cfg.Steps = 9 // Z-Image turbo default
+		cfg.Steps = 9 // Turbo default
 	}
 	if cfg.CFGScale <= 0 {
 		cfg.CFGScale = 4.0
@@ -136,8 +136,16 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	// Model applies its own defaults for width/height/steps
-	// Only seed needs to be set here if not provided
+	// Apply defaults
+	if req.Width <= 0 {
+		req.Width = 1024
+	}
+	if req.Height <= 0 {
+		req.Height = 1024
+	}
+	if req.Steps <= 0 {
+		req.Steps = 9
+	}
 	if req.Seed <= 0 {
 		req.Seed = time.Now().UnixNano()
 	}
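The handler now fills defaults itself instead of trusting the model to do it, so a minimal body like `{"prompt":"a red cube"}` leaves the handler at 1024x1024, 9 steps, and a time-based seed. A standalone demonstration of the same defaulting, on a mirror of the request fields (the real request type lives alongside the handler and is not shown in this diff):

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// req mirrors the handler's request fields for illustration only.
type req struct {
	Prompt string `json:"prompt"`
	Width  int32  `json:"width,omitempty"`
	Height int32  `json:"height,omitempty"`
	Steps  int    `json:"steps,omitempty"`
	Seed   int64  `json:"seed,omitempty"`
}

func main() {
	var r req
	if err := json.Unmarshal([]byte(`{"prompt":"a red cube"}`), &r); err != nil {
		panic(err)
	}

	// Same zero-value defaulting the handler performs.
	if r.Width <= 0 {
		r.Width = 1024
	}
	if r.Height <= 0 {
		r.Height = 1024
	}
	if r.Steps <= 0 {
		r.Steps = 9
	}
	if r.Seed <= 0 {
		r.Seed = time.Now().UnixNano()
	}

	out, _ := json.Marshal(r)
	fmt.Println(string(out)) // width/height 1024, steps 9, seed set
}
```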
@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"bytes"
 	"context"
-	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -26,11 +25,6 @@ import (
 )
 
 // Server wraps an image generation subprocess to implement llm.LlamaServer.
-//
-// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
-// like any other model. The plan is to eventually bring this into the llm/ package
-// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
-// separate allows for independent iteration on image generation support.
 type Server struct {
 	mu  sync.Mutex
 	cmd *exec.Cmd
@@ -43,6 +37,22 @@ type Server struct {
 	lastErrLock sync.Mutex
 }
 
+// completionRequest is sent to the subprocess
+type completionRequest struct {
+	Prompt string `json:"prompt"`
+	Width  int32  `json:"width,omitempty"`
+	Height int32  `json:"height,omitempty"`
+	Steps  int    `json:"steps,omitempty"`
+	Seed   int64  `json:"seed,omitempty"`
+}
+
+// completionResponse is received from the subprocess
+type completionResponse struct {
+	Content string `json:"content,omitempty"`
+	Image   string `json:"image,omitempty"`
+	Done    bool   `json:"done"`
+}
+
 // NewServer spawns a new image generation subprocess and waits until it's ready.
 func NewServer(modelName string) (*Server, error) {
 	// Validate platform support before attempting to start
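With `omitempty` on every field except `done`, the wire format stays compact. A round-trip sketch of the two new types (copied verbatim from the hunk) showing what actually crosses the local HTTP boundary:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type completionRequest struct {
	Prompt string `json:"prompt"`
	Width  int32  `json:"width,omitempty"`
	Height int32  `json:"height,omitempty"`
	Steps  int    `json:"steps,omitempty"`
	Seed   int64  `json:"seed,omitempty"`
}

type completionResponse struct {
	Content string `json:"content,omitempty"`
	Image   string `json:"image,omitempty"`
	Done    bool   `json:"done"`
}

func main() {
	// Zero-valued fields are dropped from the request...
	b, _ := json.Marshal(completionRequest{Prompt: "a lighthouse at dusk", Steps: 9})
	fmt.Println(string(b)) // {"prompt":"a lighthouse at dusk","steps":9}

	// ...and "done" is always present in responses, even mid-stream.
	var resp completionResponse
	_ = json.Unmarshal([]byte(`{"content":"step 3/9","done":false}`), &resp)
	fmt.Printf("%+v\n", resp) // {Content:step 3/9 Image: Done:false}
}
```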
@@ -129,6 +139,7 @@ func NewServer(modelName string) (*Server, error) {
 	for scanner.Scan() {
 		line := scanner.Text()
 		slog.Warn("image-runner", "msg", line)
+		// Capture last error line for better error reporting
 		s.lastErrLock.Lock()
 		s.lastErr = line
 		s.lastErrLock.Unlock()
@@ -160,6 +171,7 @@ func (s *Server) ModelPath() string {
 	return s.modelName
 }
 
+// Load is called by the scheduler after the server is created.
 func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
 	return nil, nil
 }
@@ -192,16 +204,20 @@ func (s *Server) waitUntilRunning() error {
 	for {
 		select {
 		case err := <-s.done:
-			// Include recent stderr lines for better error context
-			errMsg := s.getLastErr()
-			if errMsg != "" {
-				return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
+			// Include last stderr line for better error context
+			s.lastErrLock.Lock()
+			lastErr := s.lastErr
+			s.lastErrLock.Unlock()
+			if lastErr != "" {
+				return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
 			}
 			return fmt.Errorf("image runner exited unexpectedly: %w", err)
 		case <-timeout:
-			errMsg := s.getLastErr()
-			if errMsg != "" {
-				return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
+			s.lastErrLock.Lock()
+			lastErr := s.lastErr
+			s.lastErrLock.Unlock()
+			if lastErr != "" {
+				return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
 			}
 			return errors.New("timeout waiting for image runner to start")
 		case <-ticker.C:
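This inlines the lock/read/unlock sequence at both call sites; the `getLastErr` helper itself is deleted in the next hunk. If the repetition ever grows past two sites, an atomic value is a common alternative to the manual mutex dance; a sketch of that alternative, not what this change does:

```go
import "sync/atomic"

// lastErr as an atomic string: Store on every stderr line, Load wherever
// the last line is needed, no explicit locking at either end.
var lastErr atomic.Value

func setLastErr(line string) { lastErr.Store(line) }

func getLastErr() string {
	v, _ := lastErr.Load().(string)
	return v
}
```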
@@ -213,39 +229,44 @@ func (s *Server) waitUntilRunning() error {
 	}
 }
 
-// getLastErr returns the last stderr line.
-func (s *Server) getLastErr() string {
-	s.lastErrLock.Lock()
-	defer s.lastErrLock.Unlock()
-	return s.lastErr
+// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
+func (s *Server) WaitUntilRunning(ctx context.Context) error {
+	return nil
 }
 
-func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
-
+// Completion generates an image from the prompt via the subprocess.
 func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
-	seed := req.Seed
-	if seed == 0 {
-		seed = time.Now().UnixNano()
-	}
-
-	// Build request for subprocess
-	creq := struct {
-		Prompt string `json:"prompt"`
-		Width  int32  `json:"width,omitempty"`
-		Height int32  `json:"height,omitempty"`
-		Seed   int64  `json:"seed,omitempty"`
-	}{
+	// Build request
+	creq := completionRequest{
 		Prompt: req.Prompt,
-		Width:  req.Width,
-		Height: req.Height,
-		Seed:   seed,
+		Width:  1024,
+		Height: 1024,
+		Steps:  9,
+		Seed:   time.Now().UnixNano(),
 	}
 
+	if req.Options != nil {
+		if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
+			creq.Width = int32(req.Options.NumCtx)
+		}
+		if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
+			creq.Height = int32(req.Options.NumGPU)
+		}
+		if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
+			creq.Steps = req.Options.NumPredict
+		}
+		if req.Options.Seed > 0 {
+			creq.Seed = int64(req.Options.Seed)
+		}
+	}
+
+	// Encode request body
 	body, err := json.Marshal(creq)
 	if err != nil {
 		return err
 	}
 
+	// Send request to subprocess
 	url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
 	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
 	if err != nil {
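Since `llm.CompletionRequest` carries no width/height/steps of its own in this path, the new code repurposes ordinary generation options as carriers: `num_ctx` maps to width, `num_gpu` to height, `num_predict` to steps, each with a sanity cap. From an API client, that mapping would be exercised roughly like this (a sketch against the public `api` package; the model name is a placeholder and whether a given request reaches this runner depends on scheduler wiring not shown here):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "some-image-model", // placeholder name
		Prompt: "a lighthouse at dusk",
		Options: map[string]any{
			"num_ctx":     768,  // becomes width (capped at 4096)
			"num_gpu":     1024, // becomes height (capped at 4096)
			"num_predict": 9,    // becomes steps (capped at 100)
			"seed":        42,
		},
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```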
@@ -260,40 +281,30 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
-		return fmt.Errorf("request failed: %d", resp.StatusCode)
+		return fmt.Errorf("completion request failed: %d", resp.StatusCode)
 	}
 
+	// Stream responses - use large buffer for base64 image data
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
 	for scanner.Scan() {
-		// Parse subprocess response (has singular "image" field)
-		var raw struct {
-			Image   string `json:"image,omitempty"`
-			Content string `json:"content,omitempty"`
-			Done    bool   `json:"done"`
-			Step    int    `json:"step,omitempty"`
-			Total   int    `json:"total,omitempty"`
-		}
-		if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
+		var cresp completionResponse
+		if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
 			continue
 		}
 
-		// Convert to llm.CompletionResponse
-		cresp := llm.CompletionResponse{
-			Content: raw.Content,
-			Done:    raw.Done,
-			Step:    raw.Step,
-			Total:   raw.Total,
-		}
-		if raw.Image != "" {
-			if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
-				cresp.Image = data
-			}
-		}
+		content := cresp.Content
+		// If this is the final response with an image, encode it in the content
+		if cresp.Done && cresp.Image != "" {
+			content = "IMAGE_BASE64:" + cresp.Image
+		}
 
-		fn(cresp)
+		fn(llm.CompletionResponse{
+			Content: content,
+			Done:    cresp.Done,
+		})
 		if cresp.Done {
-			return nil
+			break
 		}
 	}
 
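The subprocess's base64 payload is no longer decoded here; instead the final chunk's content is tagged with an `IMAGE_BASE64:` prefix, and decoding becomes the consumer's job. A minimal consumer-side sketch:

```go
package main

import (
	"encoding/base64"
	"fmt"
	"os"
	"strings"
)

// extractImage recovers raw image bytes from a completion's content,
// following the "IMAGE_BASE64:" convention introduced above.
func extractImage(content string) ([]byte, bool) {
	const prefix = "IMAGE_BASE64:"
	if !strings.HasPrefix(content, prefix) {
		return nil, false
	}
	data, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(content, prefix))
	if err != nil {
		return nil, false
	}
	return data, true
}

func main() {
	// Stand-in payload; a real response would carry PNG bytes.
	content := "IMAGE_BASE64:" + base64.StdEncoding.EncodeToString([]byte("fake-png-bytes"))
	if img, ok := extractImage(content); ok {
		if err := os.WriteFile("out.png", img, 0o644); err != nil {
			panic(err)
		}
		fmt.Printf("wrote %d bytes\n", len(img))
	}
}
```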
@@ -335,18 +346,22 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
 	return s.vramSize
 }
 
+// Embedding is not supported for image generation models.
 func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
-	return nil, 0, errors.New("not supported")
+	return nil, 0, errors.New("embedding not supported for image generation models")
 }
 
+// Tokenize is not supported for image generation models.
 func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
-	return nil, errors.New("not supported")
+	return nil, errors.New("tokenize not supported for image generation models")
 }
 
+// Detokenize is not supported for image generation models.
 func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
-	return "", errors.New("not supported")
+	return "", errors.New("detokenize not supported for image generation models")
 }
 
+// Pid returns the subprocess PID.
 func (s *Server) Pid() int {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -356,9 +371,17 @@ func (s *Server) Pid() int {
 	return -1
 }
 
-func (s *Server) GetPort() int { return s.port }
-func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
+// GetPort returns the subprocess port.
+func (s *Server) GetPort() int {
+	return s.port
+}
+
+// GetDeviceInfos returns nil since we don't track GPU info.
+func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+	return nil
+}
 
+// HasExited returns true if the subprocess has exited.
 func (s *Server) HasExited() bool {
 	select {
 	case <-s.done:
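The hunk is truncated here, but the `select` over `s.done` is the standard non-blocking liveness probe. Its usual completion, for orientation (the remaining lines are not shown in this diff, so this is the idiom rather than the verbatim body):

```go
// Non-blocking probe: a receive on s.done succeeds once the monitor
// goroutine has reported subprocess exit; otherwise fall through.
func (s *Server) HasExited() bool {
	select {
	case <-s.done:
		return true
	default:
		return false
	}
}
```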