mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-17 13:28:31 -04:00
feat(api): add ollama compatibility (#9284)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
b0d9ce4905
commit
85be4ff03c
@@ -62,6 +62,7 @@ type RunCMD struct {
|
||||
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
|
||||
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
|
||||
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
|
||||
OllamaAPIRootEndpoint bool `env:"LOCALAI_OLLAMA_API_ROOT_ENDPOINT" default:"false" help:"Register Ollama-compatible health check on / (replaces web UI on root path). The /api/* Ollama endpoints are always available regardless of this flag" group:"api"`
|
||||
DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
|
||||
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
|
||||
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
||||
@@ -295,6 +296,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.DisableWebUI)
|
||||
}
|
||||
|
||||
if r.OllamaAPIRootEndpoint {
|
||||
opts = append(opts, config.EnableOllamaAPIRootEndpoint)
|
||||
}
|
||||
|
||||
if r.DisableGalleryEndpoint {
|
||||
opts = append(opts, config.DisableGalleryEndpoint)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ type ApplicationConfig struct {
|
||||
Federated bool
|
||||
|
||||
DisableWebUI bool
|
||||
OllamaAPIRootEndpoint bool
|
||||
EnforcePredownloadScans bool
|
||||
OpaqueErrors bool
|
||||
UseSubtleKeyComparison bool
|
||||
@@ -263,6 +264,10 @@ var DisableWebUI = func(o *ApplicationConfig) {
|
||||
o.DisableWebUI = true
|
||||
}
|
||||
|
||||
// EnableOllamaAPIRootEndpoint is an ApplicationConfig option that registers the
// Ollama-compatible health check on the root path "/" (replacing the web UI there).
var EnableOllamaAPIRootEndpoint = func(o *ApplicationConfig) {
	o.OllamaAPIRootEndpoint = true
}
|
||||
|
||||
// DisableRuntimeSettings is an ApplicationConfig option that prevents loading
// runtime settings (the runtime_settings.json file per the CLI flag help).
var DisableRuntimeSettings = func(o *ApplicationConfig) {
	o.DisableRuntimeSettings = true
}
|
||||
|
||||
@@ -391,6 +391,10 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
|
||||
routes.RegisterAnthropicRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application)
|
||||
routes.RegisterOllamaRoutes(e, requestExtractor, application)
|
||||
if application.ApplicationConfig().OllamaAPIRootEndpoint {
|
||||
routes.RegisterOllamaRootEndpoint(e)
|
||||
}
|
||||
if !application.ApplicationConfig().DisableWebUI {
|
||||
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware)
|
||||
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware)
|
||||
|
||||
153
core/http/endpoints/ollama/chat.go
Normal file
153
core/http/endpoints/ollama/chat.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ChatEndpoint handles Ollama-compatible /api/chat requests
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaChatRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return ollamaError(c, 400, "model is required")
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return ollamaError(c, 400, "model configuration not found")
|
||||
}
|
||||
|
||||
// Apply Ollama options to config
|
||||
applyOllamaOptions(input.Options, cfg)
|
||||
|
||||
// Convert Ollama messages to OpenAI format
|
||||
openAIMessages := ollamaMessagesToOpenAI(input.Messages)
|
||||
|
||||
// Build an OpenAI-compatible request
|
||||
openAIReq := &schema.OpenAIRequest{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
|
||||
},
|
||||
Messages: openAIMessages,
|
||||
Stream: input.IsStream(),
|
||||
Context: input.Context,
|
||||
Cancel: input.Cancel,
|
||||
}
|
||||
|
||||
if input.Options != nil {
|
||||
openAIReq.Temperature = input.Options.Temperature
|
||||
openAIReq.TopP = input.Options.TopP
|
||||
openAIReq.TopK = input.Options.TopK
|
||||
openAIReq.RepeatPenalty = input.Options.RepeatPenalty
|
||||
if input.Options.NumPredict != nil {
|
||||
openAIReq.Maxtokens = input.Options.NumPredict
|
||||
}
|
||||
if len(input.Options.Stop) > 0 {
|
||||
openAIReq.Stop = input.Options.Stop
|
||||
}
|
||||
}
|
||||
|
||||
predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, nil, false)
|
||||
xlog.Debug("Ollama Chat - Prompt (after templating)", "prompt_len", len(predInput))
|
||||
|
||||
if input.IsStream() {
|
||||
return handleOllamaChatStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
|
||||
return handleOllamaChatNonStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
}
|
||||
|
||||
func handleOllamaChatNonStream(c echo.Context, input *schema.OllamaChatRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
startTime := time.Now()
|
||||
var result string
|
||||
|
||||
cb := func(s string, choices *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama chat inference failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
|
||||
resp := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: schema.OllamaMessage{
|
||||
Role: "assistant",
|
||||
Content: result,
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
func handleOllamaChatStream(c echo.Context, input *schema.OllamaChatRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
c.Response().Header().Set("Content-Type", "application/x-ndjson")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
tokenCallback := func(token string, usage backend.TokenUsage) bool {
|
||||
chunk := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: schema.OllamaMessage{
|
||||
Role: "assistant",
|
||||
Content: token,
|
||||
},
|
||||
Done: false,
|
||||
}
|
||||
return writeNDJSON(c, chunk)
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, choices *[]schema.Choice) {}, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama chat stream inference failed", "error", err)
|
||||
errChunk := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "error",
|
||||
}
|
||||
writeNDJSON(c, errChunk)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send final done message
|
||||
totalDuration := time.Since(startTime)
|
||||
finalChunk := schema.OllamaChatResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: schema.OllamaMessage{Role: "assistant", Content: ""},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
writeNDJSON(c, finalChunk)
|
||||
|
||||
return nil
|
||||
}
|
||||
67
core/http/endpoints/ollama/embed.go
Normal file
67
core/http/endpoints/ollama/embed.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// EmbedEndpoint handles Ollama-compatible /api/embed and /api/embeddings requests
|
||||
func EmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaEmbedRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return ollamaError(c, 400, "model is required")
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return ollamaError(c, 400, "model configuration not found")
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
inputStrings := input.GetInputStrings()
|
||||
if len(inputStrings) == 0 {
|
||||
return ollamaError(c, 400, "input is required")
|
||||
}
|
||||
|
||||
var allEmbeddings [][]float32
|
||||
promptEvalCount := 0
|
||||
|
||||
for _, s := range inputStrings {
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *cfg, appConfig)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama embed failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("embedding failed: %v", err))
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
xlog.Error("Ollama embed computation failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("embedding computation failed: %v", err))
|
||||
}
|
||||
|
||||
allEmbeddings = append(allEmbeddings, embeddings)
|
||||
// Rough token count estimate
|
||||
promptEvalCount += len(s) / 4
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
|
||||
resp := schema.OllamaEmbedResponse{
|
||||
Model: input.Model,
|
||||
Embeddings: allEmbeddings,
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: promptEvalCount,
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
179
core/http/endpoints/ollama/generate.go
Normal file
179
core/http/endpoints/ollama/generate.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// GenerateEndpoint handles Ollama-compatible /api/generate requests
|
||||
func GenerateEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaGenerateRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return ollamaError(c, 400, "model is required")
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return ollamaError(c, 400, "model configuration not found")
|
||||
}
|
||||
|
||||
// Handle empty prompt — return immediately with "load" reason
|
||||
if input.Prompt == "" {
|
||||
resp := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: "",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if input.IsStream() {
|
||||
c.Response().Header().Set("Content-Type", "application/x-ndjson")
|
||||
writeNDJSON(c, resp)
|
||||
return nil
|
||||
}
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
applyOllamaOptions(input.Options, cfg)
|
||||
|
||||
// Build messages from prompt
|
||||
var messages []schema.Message
|
||||
if input.System != "" {
|
||||
messages = append(messages, schema.Message{
|
||||
Role: "system",
|
||||
StringContent: input.System,
|
||||
Content: input.System,
|
||||
})
|
||||
}
|
||||
messages = append(messages, schema.Message{
|
||||
Role: "user",
|
||||
StringContent: input.Prompt,
|
||||
Content: input.Prompt,
|
||||
})
|
||||
|
||||
openAIReq := &schema.OpenAIRequest{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
|
||||
},
|
||||
Messages: messages,
|
||||
Stream: input.IsStream(),
|
||||
Context: input.Ctx,
|
||||
Cancel: input.Cancel,
|
||||
}
|
||||
|
||||
if input.Options != nil {
|
||||
openAIReq.Temperature = input.Options.Temperature
|
||||
openAIReq.TopP = input.Options.TopP
|
||||
openAIReq.TopK = input.Options.TopK
|
||||
openAIReq.RepeatPenalty = input.Options.RepeatPenalty
|
||||
if input.Options.NumPredict != nil {
|
||||
openAIReq.Maxtokens = input.Options.NumPredict
|
||||
}
|
||||
if len(input.Options.Stop) > 0 {
|
||||
openAIReq.Stop = input.Options.Stop
|
||||
}
|
||||
}
|
||||
|
||||
var predInput string
|
||||
if input.Raw {
|
||||
// Raw mode: skip chat template, use prompt directly
|
||||
predInput = input.Prompt
|
||||
} else {
|
||||
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, nil, false)
|
||||
}
|
||||
xlog.Debug("Ollama Generate - Prompt", "prompt_len", len(predInput), "raw", input.Raw)
|
||||
|
||||
if input.IsStream() {
|
||||
return handleOllamaGenerateStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
|
||||
return handleOllamaGenerateNonStream(c, input, cfg, ml, cl, appConfig, predInput, openAIReq)
|
||||
}
|
||||
}
|
||||
|
||||
func handleOllamaGenerateNonStream(c echo.Context, input *schema.OllamaGenerateRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
startTime := time.Now()
|
||||
var result string
|
||||
|
||||
cb := func(s string, choices *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama generate inference failed", "error", err)
|
||||
return ollamaError(c, 500, fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
|
||||
resp := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: result,
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
|
||||
func handleOllamaGenerateStream(c echo.Context, input *schema.OllamaGenerateRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest) error {
|
||||
c.Response().Header().Set("Content-Type", "application/x-ndjson")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
tokenCallback := func(token string, usage backend.TokenUsage) bool {
|
||||
chunk := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: token,
|
||||
Done: false,
|
||||
}
|
||||
return writeNDJSON(c, chunk)
|
||||
}
|
||||
|
||||
_, tokenUsage, _, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, choices *[]schema.Choice) {}, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Ollama generate stream inference failed", "error", err)
|
||||
errChunk := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "error",
|
||||
}
|
||||
writeNDJSON(c, errChunk)
|
||||
return nil
|
||||
}
|
||||
|
||||
totalDuration := time.Since(startTime)
|
||||
finalChunk := schema.OllamaGenerateResponse{
|
||||
Model: input.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: "",
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
TotalDuration: totalDuration.Nanoseconds(),
|
||||
PromptEvalCount: tokenUsage.Prompt,
|
||||
EvalCount: tokenUsage.Completion,
|
||||
}
|
||||
writeNDJSON(c, finalChunk)
|
||||
|
||||
return nil
|
||||
}
|
||||
83
core/http/endpoints/ollama/helpers.go
Normal file
83
core/http/endpoints/ollama/helpers.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// writeNDJSON writes a JSON object followed by a newline to the response (NDJSON format)
|
||||
func writeNDJSON(c echo.Context, v any) bool {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to marshal NDJSON", "error", err)
|
||||
return false
|
||||
}
|
||||
_, err = fmt.Fprintf(c.Response().Writer, "%s\n", data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
c.Response().Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
// ollamaError sends an Ollama-compatible JSON error response
|
||||
func ollamaError(c echo.Context, statusCode int, message string) error {
|
||||
return c.JSON(statusCode, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
// applyOllamaOptions applies Ollama options to the model configuration
|
||||
func applyOllamaOptions(opts *schema.OllamaOptions, cfg *config.ModelConfig) {
|
||||
if opts == nil {
|
||||
return
|
||||
}
|
||||
if opts.Temperature != nil {
|
||||
cfg.Temperature = opts.Temperature
|
||||
}
|
||||
if opts.TopP != nil {
|
||||
cfg.TopP = opts.TopP
|
||||
}
|
||||
if opts.TopK != nil {
|
||||
cfg.TopK = opts.TopK
|
||||
}
|
||||
if opts.NumPredict != nil {
|
||||
cfg.Maxtokens = opts.NumPredict
|
||||
}
|
||||
if opts.RepeatPenalty != 0 {
|
||||
cfg.RepeatPenalty = opts.RepeatPenalty
|
||||
}
|
||||
if opts.RepeatLastN != 0 {
|
||||
cfg.RepeatLastN = opts.RepeatLastN
|
||||
}
|
||||
if len(opts.Stop) > 0 {
|
||||
cfg.StopWords = append(cfg.StopWords, opts.Stop...)
|
||||
}
|
||||
if opts.NumCtx > 0 {
|
||||
cfg.ContextSize = &opts.NumCtx
|
||||
}
|
||||
}
|
||||
|
||||
// ollamaMessagesToOpenAI converts Ollama messages to OpenAI-compatible messages
|
||||
func ollamaMessagesToOpenAI(messages []schema.OllamaMessage) []schema.Message {
|
||||
var result []schema.Message
|
||||
for _, msg := range messages {
|
||||
openAIMsg := schema.Message{
|
||||
Role: msg.Role,
|
||||
StringContent: msg.Content,
|
||||
Content: msg.Content,
|
||||
}
|
||||
|
||||
// Convert base64 images to data URIs
|
||||
for _, img := range msg.Images {
|
||||
dataURI := fmt.Sprintf("data:image/png;base64,%s", img)
|
||||
openAIMsg.StringImages = append(openAIMsg.StringImages, dataURI)
|
||||
}
|
||||
|
||||
result = append(result, openAIMsg)
|
||||
}
|
||||
return result
|
||||
}
|
||||
142
core/http/endpoints/ollama/models.go
Normal file
142
core/http/endpoints/ollama/models.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
const ollamaCompatVersion = "0.9.0"
|
||||
|
||||
// ListModelsEndpoint handles Ollama-compatible GET /api/tags
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelNames, err := galleryop.ListModels(bcl, ml, nil, galleryop.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
return ollamaError(c, 500, fmt.Sprintf("failed to list models: %v", err))
|
||||
}
|
||||
|
||||
var models []schema.OllamaModelEntry
|
||||
for _, name := range modelNames {
|
||||
ollamaName := name
|
||||
if !strings.Contains(ollamaName, ":") {
|
||||
ollamaName += ":latest"
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", sha256.Sum256([]byte(name)))
|
||||
|
||||
entry := schema.OllamaModelEntry{
|
||||
Name: ollamaName,
|
||||
Model: ollamaName,
|
||||
ModifiedAt: time.Now().UTC(),
|
||||
Size: 0,
|
||||
Digest: digest,
|
||||
Details: modelDetailsFromConfig(bcl, name),
|
||||
}
|
||||
models = append(models, entry)
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.OllamaListResponse{Models: models})
|
||||
}
|
||||
}
|
||||
|
||||
// ShowModelEndpoint handles Ollama-compatible POST /api/show
|
||||
func ShowModelEndpoint(bcl *config.ModelConfigLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.OllamaShowRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return ollamaError(c, 400, "invalid request body")
|
||||
}
|
||||
|
||||
name := req.Name
|
||||
if name == "" {
|
||||
name = req.Model
|
||||
}
|
||||
if name == "" {
|
||||
return ollamaError(c, 400, "name is required")
|
||||
}
|
||||
|
||||
// Strip tag suffix for config lookup
|
||||
configName := strings.Split(name, ":")[0]
|
||||
|
||||
cfg, exists := bcl.GetModelConfig(configName)
|
||||
if !exists {
|
||||
return ollamaError(c, 404, fmt.Sprintf("model '%s' not found", name))
|
||||
}
|
||||
|
||||
resp := schema.OllamaShowResponse{
|
||||
Modelfile: fmt.Sprintf("FROM %s", cfg.Model),
|
||||
Parameters: "",
|
||||
Template: cfg.TemplateConfig.Chat,
|
||||
Details: modelDetailsFromModelConfig(&cfg),
|
||||
}
|
||||
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// ListRunningEndpoint handles Ollama-compatible GET /api/ps
|
||||
func ListRunningEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
loadedModels := ml.ListLoadedModels()
|
||||
|
||||
var models []schema.OllamaPsEntry
|
||||
for _, m := range loadedModels {
|
||||
name := m.ID
|
||||
ollamaName := name
|
||||
if !strings.Contains(ollamaName, ":") {
|
||||
ollamaName += ":latest"
|
||||
}
|
||||
|
||||
entry := schema.OllamaPsEntry{
|
||||
Name: ollamaName,
|
||||
Model: ollamaName,
|
||||
Size: 0,
|
||||
Digest: fmt.Sprintf("sha256:%x", sha256.Sum256([]byte(name))),
|
||||
Details: modelDetailsFromConfig(bcl, name),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).UTC(),
|
||||
SizeVRAM: 0,
|
||||
}
|
||||
models = append(models, entry)
|
||||
}
|
||||
|
||||
return c.JSON(200, schema.OllamaPsResponse{Models: models})
|
||||
}
|
||||
}
|
||||
|
||||
// VersionEndpoint handles Ollama-compatible GET /api/version
|
||||
func VersionEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.JSON(200, schema.OllamaVersionResponse{Version: ollamaCompatVersion})
|
||||
}
|
||||
}
|
||||
|
||||
// HeartbeatEndpoint handles the Ollama root health check
|
||||
func HeartbeatEndpoint() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return c.String(200, "Ollama is running")
|
||||
}
|
||||
}
|
||||
|
||||
func modelDetailsFromConfig(bcl *config.ModelConfigLoader, name string) schema.OllamaModelDetails {
|
||||
configName := strings.Split(name, ":")[0]
|
||||
cfg, exists := bcl.GetModelConfig(configName)
|
||||
if !exists {
|
||||
return schema.OllamaModelDetails{}
|
||||
}
|
||||
return modelDetailsFromModelConfig(&cfg)
|
||||
}
|
||||
|
||||
func modelDetailsFromModelConfig(cfg *config.ModelConfig) schema.OllamaModelDetails {
|
||||
return schema.OllamaModelDetails{
|
||||
Format: "gguf",
|
||||
Family: cfg.Backend,
|
||||
}
|
||||
}
|
||||
62
core/http/endpoints/ollama/models_test.go
Normal file
62
core/http/endpoints/ollama/models_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package ollama_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/ollama"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestOllamaEndpoints boots the Ginkgo spec suite for this package under `go test`.
func TestOllamaEndpoints(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Ollama Endpoints Suite")
}
|
||||
|
||||
var _ = Describe("Ollama endpoint handlers", func() {
|
||||
var e *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
})
|
||||
|
||||
Describe("HeartbeatEndpoint", func() {
|
||||
It("returns 'Ollama is running' on GET /", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := ollama.HeartbeatEndpoint()
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Body.String()).To(Equal("Ollama is running"))
|
||||
})
|
||||
|
||||
It("returns 200 on HEAD /", func() {
|
||||
req := httptest.NewRequest(http.MethodHead, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := ollama.HeartbeatEndpoint()
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("VersionEndpoint", func() {
|
||||
It("returns a JSON object with version field", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := ollama.VersionEndpoint()
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Body.String()).To(ContainSubstring(`"version"`))
|
||||
Expect(rec.Body.String()).To(MatchRegexp(`\d+\.\d+\.\d+`))
|
||||
})
|
||||
})
|
||||
})
|
||||
165
core/http/routes/ollama.go
Normal file
165
core/http/routes/ollama.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/ollama"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
func RegisterOllamaRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.AuthDB())
|
||||
|
||||
// Chat endpoint: POST /api/chat
|
||||
chatHandler := ollama.ChatEndpoint(
|
||||
application.ModelConfigLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OllamaChatRequest) }),
|
||||
setOllamaChatRequestContext(application.ApplicationConfig()),
|
||||
}
|
||||
app.POST("/api/chat", chatHandler, chatMiddleware...)
|
||||
|
||||
// Generate endpoint: POST /api/generate
|
||||
generateHandler := ollama.GenerateEndpoint(
|
||||
application.ModelConfigLoader(),
|
||||
application.ModelLoader(),
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
generateMiddleware := []echo.MiddlewareFunc{
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OllamaGenerateRequest) }),
|
||||
setOllamaGenerateRequestContext(application.ApplicationConfig()),
|
||||
}
|
||||
app.POST("/api/generate", generateHandler, generateMiddleware...)
|
||||
|
||||
// Embed endpoints: POST /api/embed and /api/embeddings
|
||||
embedHandler := ollama.EmbedEndpoint(
|
||||
application.ModelConfigLoader(),
|
||||
application.ModelLoader(),
|
||||
application.ApplicationConfig(),
|
||||
)
|
||||
embedMiddleware := []echo.MiddlewareFunc{
|
||||
usageMiddleware,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OllamaEmbedRequest) }),
|
||||
}
|
||||
app.POST("/api/embed", embedHandler, embedMiddleware...)
|
||||
app.POST("/api/embeddings", embedHandler, embedMiddleware...)
|
||||
|
||||
// Model management endpoints (no model-specific middleware needed)
|
||||
app.GET("/api/tags", ollama.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader()))
|
||||
app.HEAD("/api/tags", ollama.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader()))
|
||||
app.POST("/api/show", ollama.ShowModelEndpoint(application.ModelConfigLoader()))
|
||||
app.GET("/api/ps", ollama.ListRunningEndpoint(application.ModelConfigLoader(), application.ModelLoader()))
|
||||
app.GET("/api/version", ollama.VersionEndpoint())
|
||||
app.HEAD("/api/version", ollama.VersionEndpoint())
|
||||
}
|
||||
|
||||
// RegisterOllamaRootEndpoint registers the Ollama "/" health check.
|
||||
// This is separate because it conflicts with the web UI and is gated behind a CLI flag.
|
||||
func RegisterOllamaRootEndpoint(app *echo.Echo) {
|
||||
app.GET("/", ollama.HeartbeatEndpoint())
|
||||
app.HEAD("/", ollama.HeartbeatEndpoint())
|
||||
}
|
||||
|
||||
// setOllamaChatRequestContext sets up context and cancellation for Ollama chat requests
|
||||
func setOllamaChatRequestContext(appConfig *config.ApplicationConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaChatRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
correlationID := uuid.New().String()
|
||||
c.Response().Header().Set("X-Correlation-ID", correlationID)
|
||||
|
||||
reqCtx := c.Request().Context()
|
||||
c1, cancel := context.WithCancel(appConfig.Context)
|
||||
stop := context.AfterFunc(reqCtx, cancel)
|
||||
defer func() {
|
||||
stop()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
ctxWithCorrelationID := context.WithValue(c1, middleware.CorrelationIDKey, correlationID)
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
if cfg.Model == "" {
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setOllamaGenerateRequestContext sets up context and cancellation for Ollama generate requests
|
||||
func setOllamaGenerateRequestContext(appConfig *config.ApplicationConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OllamaGenerateRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return echo.ErrBadRequest
|
||||
}
|
||||
|
||||
correlationID := uuid.New().String()
|
||||
c.Response().Header().Set("X-Correlation-ID", correlationID)
|
||||
|
||||
reqCtx := c.Request().Context()
|
||||
c1, cancel := context.WithCancel(appConfig.Context)
|
||||
stop := context.AfterFunc(reqCtx, cancel)
|
||||
defer func() {
|
||||
stop()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
ctxWithCorrelationID := context.WithValue(c1, middleware.CorrelationIDKey, correlationID)
|
||||
input.Ctx = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
if cfg.Model == "" {
|
||||
cfg.Model = input.Model
|
||||
}
|
||||
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
|
||||
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
257
core/schema/ollama.go
Normal file
257
core/schema/ollama.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OllamaOptions represents runtime parameters for Ollama generation
|
||||
type OllamaOptions struct {
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
NumPredict *int `json:"num_predict,omitempty"`
|
||||
RepeatPenalty float64 `json:"repeat_penalty,omitempty"`
|
||||
RepeatLastN int `json:"repeat_last_n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaMessage represents a message in Ollama chat format
|
||||
type OllamaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
ToolCalls []any `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaChatRequest represents a request to the Ollama Chat API
|
||||
type OllamaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []OllamaMessage `json:"messages"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Format any `json:"format,omitempty"`
|
||||
Options *OllamaOptions `json:"options,omitempty"`
|
||||
Tools []any `json:"tools,omitempty"`
|
||||
|
||||
// Internal fields
|
||||
Context context.Context `json:"-"`
|
||||
Cancel context.CancelFunc `json:"-"`
|
||||
}
|
||||
|
||||
// ModelName implements the LocalAIRequest interface
|
||||
func (r *OllamaChatRequest) ModelName(s *string) string {
|
||||
if s != nil {
|
||||
r.Model = *s
|
||||
}
|
||||
return r.Model
|
||||
}
|
||||
|
||||
// IsStream returns whether streaming is enabled (defaults to true for Ollama)
|
||||
func (r *OllamaChatRequest) IsStream() bool {
|
||||
if r.Stream == nil {
|
||||
return true
|
||||
}
|
||||
return *r.Stream
|
||||
}
|
||||
|
||||
// OllamaChatResponse represents a response from the Ollama Chat API
|
||||
type OllamaChatResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Message OllamaMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
TotalDuration int64 `json:"total_duration,omitempty"`
|
||||
LoadDuration int64 `json:"load_duration,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int64 `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaGenerateRequest represents a request to the Ollama Generate API
|
||||
type OllamaGenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format any `json:"format,omitempty"`
|
||||
Options *OllamaOptions `json:"options,omitempty"`
|
||||
// Context from a previous generate call for continuation
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
// Internal fields
|
||||
Ctx context.Context `json:"-"`
|
||||
Cancel context.CancelFunc `json:"-"`
|
||||
}
|
||||
|
||||
// ModelName implements the LocalAIRequest interface
|
||||
func (r *OllamaGenerateRequest) ModelName(s *string) string {
|
||||
if s != nil {
|
||||
r.Model = *s
|
||||
}
|
||||
return r.Model
|
||||
}
|
||||
|
||||
// IsStream returns whether streaming is enabled (defaults to true for Ollama)
|
||||
func (r *OllamaGenerateRequest) IsStream() bool {
|
||||
if r.Stream == nil {
|
||||
return true
|
||||
}
|
||||
return *r.Stream
|
||||
}
|
||||
|
||||
// OllamaGenerateResponse represents a response from the Ollama Generate API
|
||||
type OllamaGenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
TotalDuration int64 `json:"total_duration,omitempty"`
|
||||
LoadDuration int64 `json:"load_duration,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int64 `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaEmbedRequest represents a request to the Ollama Embed API
|
||||
type OllamaEmbedRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"` // string or []string
|
||||
Options *OllamaOptions `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
// ModelName implements the LocalAIRequest interface
|
||||
func (r *OllamaEmbedRequest) ModelName(s *string) string {
|
||||
if s != nil {
|
||||
r.Model = *s
|
||||
}
|
||||
return r.Model
|
||||
}
|
||||
|
||||
// GetInputStrings normalizes the Input field to a string slice
|
||||
func (r *OllamaEmbedRequest) GetInputStrings() []string {
|
||||
switch v := r.Input.(type) {
|
||||
case string:
|
||||
return []string{v}
|
||||
case []any:
|
||||
var result []string
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
case []string:
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OllamaEmbedResponse represents a response from the Ollama Embed API
|
||||
type OllamaEmbedResponse struct {
|
||||
Model string `json:"model"`
|
||||
Embeddings [][]float32 `json:"embeddings"`
|
||||
TotalDuration int64 `json:"total_duration,omitempty"`
|
||||
LoadDuration int64 `json:"load_duration,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaShowRequest represents a request to the Ollama Show API
|
||||
type OllamaShowRequest struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Verbose bool `json:"verbose,omitempty"`
|
||||
}
|
||||
|
||||
// ModelName implements the LocalAIRequest interface
|
||||
func (r *OllamaShowRequest) ModelName(s *string) string {
|
||||
name := r.Name
|
||||
if name == "" {
|
||||
name = r.Model
|
||||
}
|
||||
if s != nil {
|
||||
r.Name = *s
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// OllamaShowResponse represents a response from the Ollama Show API
|
||||
type OllamaShowResponse struct {
|
||||
Modelfile string `json:"modelfile"`
|
||||
Parameters string `json:"parameters"`
|
||||
Template string `json:"template"`
|
||||
License string `json:"license,omitempty"`
|
||||
Details OllamaModelDetails `json:"details"`
|
||||
}
|
||||
|
||||
// OllamaModelDetails contains model metadata
|
||||
type OllamaModelDetails struct {
|
||||
ParentModel string `json:"parent_model,omitempty"`
|
||||
Format string `json:"format,omitempty"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Families []string `json:"families,omitempty"`
|
||||
ParameterSize string `json:"parameter_size,omitempty"`
|
||||
QuantizationLevel string `json:"quantization_level,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaModelEntry represents a model in the list response
|
||||
type OllamaModelEntry struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details OllamaModelDetails `json:"details"`
|
||||
}
|
||||
|
||||
// OllamaListResponse represents a response from the Ollama Tags API
|
||||
type OllamaListResponse struct {
|
||||
Models []OllamaModelEntry `json:"models"`
|
||||
}
|
||||
|
||||
// OllamaPsEntry represents a running model in the ps response
|
||||
type OllamaPsEntry struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details OllamaModelDetails `json:"details"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
}
|
||||
|
||||
// OllamaPsResponse represents a response from the Ollama Ps API
|
||||
type OllamaPsResponse struct {
|
||||
Models []OllamaPsEntry `json:"models"`
|
||||
}
|
||||
|
||||
// OllamaVersionResponse represents a response from the Ollama Version API
|
||||
type OllamaVersionResponse struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// OllamaPullRequest represents a request to pull a model
|
||||
type OllamaPullRequest struct {
|
||||
Name string `json:"name"`
|
||||
Insecure bool `json:"insecure,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaDeleteRequest represents a request to delete a model
|
||||
type OllamaDeleteRequest struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// OllamaCopyRequest represents a request to copy a model
|
||||
type OllamaCopyRequest struct {
|
||||
Source string `json:"source"`
|
||||
Destination string `json:"destination"`
|
||||
}
|
||||
10
go.mod
10
go.mod
@@ -42,6 +42,7 @@ require (
|
||||
github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8
|
||||
github.com/mudler/xlog v0.0.6
|
||||
github.com/nats-io/nats.go v1.50.0
|
||||
github.com/ollama/ollama v0.20.4
|
||||
github.com/onsi/ginkgo/v2 v2.28.1
|
||||
github.com/onsi/gomega v1.39.1
|
||||
github.com/openai/openai-go/v3 v3.26.0
|
||||
@@ -89,15 +90,18 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
||||
github.com/nats-io/nkeys v0.4.15 // indirect
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/tmc/langchaingo v0.1.14 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
sigs.k8s.io/yaml v1.6.0 // indirect
|
||||
)
|
||||
|
||||
@@ -206,7 +210,7 @@ require (
|
||||
github.com/BurntSushi/toml v1.5.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.8.0 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
@@ -405,7 +409,7 @@ require (
|
||||
github.com/opentracing/opentracing-go v1.2.0 // indirect
|
||||
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
|
||||
github.com/peterbourgon/diskv v2.0.1+incompatible // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.2 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.8 // indirect
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
||||
21
go.sum
21
go.sum
@@ -114,6 +114,8 @@ github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWp
|
||||
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o=
|
||||
github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
@@ -161,6 +163,8 @@ github.com/blevesearch/zapx/v16 v16.2.8/go.mod h1:murSoCJPCk25MqURrcJaBQ1RekuqSC
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
|
||||
github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||
github.com/c-robinson/iplib v1.0.8 h1:exDRViDyL9UBLcfmlxxkY5odWX5092nPsQIykHXhIn4=
|
||||
@@ -179,8 +183,8 @@ github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V
|
||||
github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
|
||||
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
|
||||
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
|
||||
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240806155701-69247e0abc2a h1:G99klV19u0QnhiizODirwVksQB91TJKV/UaTnACcG30=
|
||||
@@ -642,8 +646,8 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||
@@ -721,8 +725,6 @@ github.com/mudler/go-processmanager v0.1.0 h1:fcSKgF9U/a1Z7KofAFeZnke5YseadCI5Gq
|
||||
github.com/mudler/go-processmanager v0.1.0/go.mod h1:h6kmHUZeafr+k5hRYpGLMzJFH4hItHffgpRo2QIkP+o=
|
||||
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b h1:XeAnOEOOSKMfS5XNGpRTltQgjKCinho0V4uAhrgxN7Q=
|
||||
github.com/mudler/localrecall v0.5.9-0.20260321005011-810084e9369b/go.mod h1:xuPtgL9zUyiQLmspYzO3kaboYrGbWmwi8BQPt1aCAcs=
|
||||
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2 h1:+WHsL/j6EWOMUiMVIOJNKOwSKiQt/qDPc9fePCf87fA=
|
||||
github.com/mudler/memory v0.0.0-20251216220809-d1256471a6c2/go.mod h1:EA8Ashhd56o32qN7ouPKFSRUs/Z+LrRCF4v6R2Oarm8=
|
||||
github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8 h1:Ry8RiWy8fZ6Ff4E7dPmjRsBrnHOnPeOOj2LhCgyjQu0=
|
||||
github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8/go.mod h1:EA8Ashhd56o32qN7ouPKFSRUs/Z+LrRCF4v6R2Oarm8=
|
||||
github.com/mudler/skillserver v0.0.6 h1:ixz6wUekLdTmbnpAavCkTydDF6UdXAG3ncYufSPK9G0=
|
||||
@@ -777,6 +779,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/ollama/ollama v0.20.4 h1:XXquZkzAptOoAzNHAyKQOhiShoDFMfn3Yp56C7Vfsjs=
|
||||
github.com/ollama/ollama v0.20.4/go.mod h1:tCX4IMV8DHjl3zY0THxuEkpWDZSOchJpzTuLACpMwFw=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.28.1 h1:S4hj+HbZp40fNKuLUQOYLDgZLwNUVn19N3Atb98NCyI=
|
||||
@@ -809,8 +813,9 @@ github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rK
|
||||
github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY=
|
||||
github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
|
||||
github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM=
|
||||
github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0=
|
||||
github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk=
|
||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||
@@ -1065,6 +1070,8 @@ github.com/warpfork/go-wish v0.0.0-20220906213052-39a1cc7a02d0 h1:GDDkbFiaK8jsSD
|
||||
github.com/warpfork/go-wish v0.0.0-20220906213052-39a1cc7a02d0/go.mod h1:x6AKhvSSexNrVSrViXSHUEbICjmGXhtgABaHIySUSGw=
|
||||
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k=
|
||||
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1/go.mod h1:8UvriyWtv5Q5EOgjHaSseUEdkQfvwFv1I/In/O2M9gc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
|
||||
349
tests/e2e/e2e_ollama_test.go
Normal file
349
tests/e2e/e2e_ollama_test.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package e2e_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Ollama API E2E test", Label("Ollama"), func() {
|
||||
var client *api.Client
|
||||
|
||||
Context("API with Ollama client", func() {
|
||||
BeforeEach(func() {
|
||||
u, err := url.Parse(ollamaBaseURL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client = api.NewClient(u, http.DefaultClient)
|
||||
})
|
||||
|
||||
Context("Model management", func() {
|
||||
It("lists available models via /api/tags", func() {
|
||||
resp, err := client.List(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Models).ToNot(BeEmpty())
|
||||
|
||||
// Find mock-model and validate its fields
|
||||
var found *api.ListModelResponse
|
||||
for i, m := range resp.Models {
|
||||
if m.Name == "mock-model:latest" {
|
||||
found = &resp.Models[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).ToNot(BeNil(), "mock-model:latest should be in the list")
|
||||
Expect(found.Model).To(Equal("mock-model:latest"))
|
||||
Expect(found.Digest).ToNot(BeEmpty())
|
||||
Expect(found.ModifiedAt).ToNot(BeZero())
|
||||
})
|
||||
|
||||
It("shows model details via /api/show", func() {
|
||||
resp, err := client.Show(context.TODO(), &api.ShowRequest{
|
||||
Name: "mock-model",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Modelfile).To(ContainSubstring("FROM"))
|
||||
Expect(resp.Details.Format).To(Equal("gguf"))
|
||||
})
|
||||
|
||||
It("returns 404 for unknown model in /api/show", func() {
|
||||
_, err := client.Show(context.TODO(), &api.ShowRequest{
|
||||
Name: "nonexistent-model",
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns version via /api/version", func() {
|
||||
version, err := client.Version(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(version).ToNot(BeEmpty())
|
||||
// Should be a semver-like string
|
||||
Expect(version).To(MatchRegexp(`^\d+\.\d+\.\d+`))
|
||||
})
|
||||
|
||||
It("responds to HEAD /api/version", func() {
|
||||
req, err := http.NewRequest("HEAD", fmt.Sprintf("%s/api/version", ollamaBaseURL), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
})
|
||||
|
||||
It("responds to HEAD /api/tags", func() {
|
||||
req, err := http.NewRequest("HEAD", fmt.Sprintf("%s/api/tags", ollamaBaseURL), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
})
|
||||
|
||||
// Heartbeat (HEAD /) requires the OllamaAPIRootEndpoint CLI flag
|
||||
// which is not enabled in the default test setup.
|
||||
|
||||
It("lists running models via /api/ps after a model has been loaded", func() {
|
||||
// First, trigger a chat to ensure the model is loaded
|
||||
stream := false
|
||||
err := client.Chat(context.TODO(), &api.ChatRequest{
|
||||
Model: "mock-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "ping"}},
|
||||
Stream: &stream,
|
||||
}, func(resp api.ChatResponse) error { return nil })
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Now check ps
|
||||
resp, err := client.ListRunning(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.Models).ToNot(BeEmpty(), "at least one model should be loaded after chat")
|
||||
|
||||
var found bool
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == "mock-model:latest" {
|
||||
found = true
|
||||
Expect(m.Digest).ToNot(BeEmpty())
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).To(BeTrue(), "mock-model should appear in running models")
|
||||
})
|
||||
})
|
||||
|
||||
Context("Chat endpoint", func() {
|
||||
It("generates a non-streaming chat response with valid fields", func() {
|
||||
stream := false
|
||||
var finalResp api.ChatResponse
|
||||
|
||||
err := client.Chat(context.TODO(), &api.ChatRequest{
|
||||
Model: "mock-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "How much is 2+2?"},
|
||||
},
|
||||
Stream: &stream,
|
||||
}, func(resp api.ChatResponse) error {
|
||||
finalResp = resp
|
||||
return nil
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(finalResp.Done).To(BeTrue())
|
||||
Expect(finalResp.DoneReason).To(Equal("stop"))
|
||||
Expect(finalResp.Message.Role).To(Equal("assistant"))
|
||||
Expect(finalResp.Message.Content).ToNot(BeEmpty())
|
||||
Expect(finalResp.Model).To(Equal("mock-model"))
|
||||
Expect(finalResp.CreatedAt).ToNot(BeZero())
|
||||
Expect(finalResp.TotalDuration).To(BeNumerically(">", 0))
|
||||
})
|
||||
|
||||
It("streams tokens incrementally", func() {
|
||||
stream := true
|
||||
var chunks []api.ChatResponse
|
||||
|
||||
err := client.Chat(context.TODO(), &api.ChatRequest{
|
||||
Model: "mock-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Say hello"},
|
||||
},
|
||||
Stream: &stream,
|
||||
}, func(resp api.ChatResponse) error {
|
||||
chunks = append(chunks, resp)
|
||||
return nil
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(chunks)).To(BeNumerically(">=", 2), "should have at least one content chunk + done chunk")
|
||||
|
||||
// Last chunk must be the done signal
|
||||
lastChunk := chunks[len(chunks)-1]
|
||||
Expect(lastChunk.Done).To(BeTrue())
|
||||
Expect(lastChunk.DoneReason).To(Equal("stop"))
|
||||
Expect(lastChunk.TotalDuration).To(BeNumerically(">", 0))
|
||||
|
||||
// Non-final chunks should carry content
|
||||
hasContent := false
|
||||
for _, c := range chunks[:len(chunks)-1] {
|
||||
if c.Message.Content != "" {
|
||||
hasContent = true
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(hasContent).To(BeTrue(), "intermediate streaming chunks should carry token content")
|
||||
})
|
||||
|
||||
It("handles multi-turn conversation with system prompt", func() {
|
||||
stream := false
|
||||
var finalResp api.ChatResponse
|
||||
|
||||
err := client.Chat(context.TODO(), &api.ChatRequest{
|
||||
Model: "mock-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is Go?"},
|
||||
{Role: "assistant", Content: "Go is a programming language."},
|
||||
{Role: "user", Content: "Who created it?"},
|
||||
},
|
||||
Stream: &stream,
|
||||
}, func(resp api.ChatResponse) error {
|
||||
finalResp = resp
|
||||
return nil
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(finalResp.Done).To(BeTrue())
|
||||
Expect(finalResp.Message.Content).ToNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Generate endpoint", func() {
	It("generates a non-streaming response with valid fields", func() {
		streaming := false
		var got api.GenerateResponse

		// Build the request up front so the Generate call stays readable.
		req := &api.GenerateRequest{
			Model:  "mock-model",
			Prompt: "Once upon a time",
			Stream: &streaming,
		}
		err := client.Generate(context.TODO(), req, func(resp api.GenerateResponse) error {
			got = resp
			return nil
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(got.Done).To(BeTrue())
		Expect(got.DoneReason).To(Equal("stop"))
		Expect(got.Response).ToNot(BeEmpty())
		Expect(got.Model).To(Equal("mock-model"))
		Expect(got.CreatedAt).ToNot(BeZero())
		Expect(got.TotalDuration).To(BeNumerically(">", 0))
	})

	It("streams tokens incrementally", func() {
		streaming := true
		var received []api.GenerateResponse

		req := &api.GenerateRequest{
			Model:  "mock-model",
			Prompt: "Tell me a story",
			Stream: &streaming,
		}
		err := client.Generate(context.TODO(), req, func(resp api.GenerateResponse) error {
			received = append(received, resp)
			return nil
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(len(received)).To(BeNumerically(">=", 2))

		final := received[len(received)-1]
		Expect(final.Done).To(BeTrue())
		Expect(final.DoneReason).To(Equal("stop"))

		// Check that intermediate chunks have response text
		sawToken := false
		for i := 0; i < len(received)-1 && !sawToken; i++ {
			sawToken = received[i].Response != ""
		}
		Expect(sawToken).To(BeTrue(), "intermediate streaming chunks should carry token content")
	})

	It("returns load response for empty prompt", func() {
		streaming := false
		var got api.GenerateResponse

		// An empty prompt only loads the model; the server signals this
		// via DoneReason == "load".
		req := &api.GenerateRequest{
			Model:  "mock-model",
			Prompt: "",
			Stream: &streaming,
		}
		err := client.Generate(context.TODO(), req, func(resp api.GenerateResponse) error {
			got = resp
			return nil
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(got.Done).To(BeTrue())
		Expect(got.DoneReason).To(Equal("load"))
	})

	It("supports system prompt in generate", func() {
		streaming := false
		var got api.GenerateResponse

		req := &api.GenerateRequest{
			Model:  "mock-model",
			Prompt: "Hello",
			System: "You are a pirate.",
			Stream: &streaming,
		}
		err := client.Generate(context.TODO(), req, func(resp api.GenerateResponse) error {
			got = resp
			return nil
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(got.Done).To(BeTrue())
		Expect(got.Response).ToNot(BeEmpty())
	})
})
|
||||
|
||||
Context("Embed endpoint", func() {
	It("generates embeddings for a single input via /api/embed", func() {
		resp, err := client.Embed(context.TODO(), &api.EmbedRequest{
			Model: "mock-model",
			Input: "Hello, world!",
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(resp.Embeddings).To(HaveLen(1))
		Expect(len(resp.Embeddings[0])).To(BeNumerically(">", 0), "embedding vector should have dimensions")
		Expect(resp.Model).To(Equal("mock-model"))
	})

	It("generates embeddings via the legacy /api/embeddings alias", func() {
		// The ollama client uses /api/embed, so test the legacy endpoint with raw HTTP
		body := map[string]any{
			"model": "mock-model",
			"input": "test input",
		}
		bodyJSON, err := json.Marshal(body)
		Expect(err).ToNot(HaveOccurred())

		resp, err := http.Post(
			fmt.Sprintf("%s/api/embeddings", ollamaBaseURL),
			"application/json",
			bytes.NewReader(bodyJSON),
		)
		Expect(err).ToNot(HaveOccurred())
		defer resp.Body.Close()

		// Read the body before asserting on the status so a failure
		// message can include the server's error payload.
		respBody, err := io.ReadAll(resp.Body)
		Expect(err).ToNot(HaveOccurred())
		// Use the named status constant instead of the magic number 200.
		Expect(resp.StatusCode).To(Equal(http.StatusOK), "unexpected status, body: %s", string(respBody))

		var result map[string]any
		Expect(json.Unmarshal(respBody, &result)).To(Succeed())
		Expect(result).To(HaveKey("embeddings"))
	})
})
|
||||
|
||||
Context("Error handling", func() {
	It("returns error for chat with unknown model", func() {
		noStream := false
		// Discard any callback payload; only the returned error matters.
		discard := func(api.ChatResponse) error { return nil }

		err := client.Chat(context.TODO(), &api.ChatRequest{
			Model:    "nonexistent-model-xyz",
			Messages: []api.Message{{Role: "user", Content: "hi"}},
			Stream:   &noStream,
		}, discard)
		Expect(err).To(HaveOccurred())
	})

	It("returns error for generate with unknown model", func() {
		noStream := false
		discard := func(api.GenerateResponse) error { return nil }

		err := client.Generate(context.TODO(), &api.GenerateRequest{
			Model:  "nonexistent-model-xyz",
			Prompt: "hi",
			Stream: &noStream,
		}, discard)
		Expect(err).To(HaveOccurred())
	})
})
|
||||
})
|
||||
})
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
var (
|
||||
anthropicBaseURL string
|
||||
ollamaBaseURL string
|
||||
tmpDir string
|
||||
backendPath string
|
||||
modelsPath string
|
||||
@@ -245,6 +246,8 @@ var _ = BeforeSuite(func() {
|
||||
apiURL = fmt.Sprintf("http://127.0.0.1:%d/v1", apiPort)
|
||||
// Anthropic SDK appends /v1/messages to base URL; use base without /v1 so requests go to /v1/messages
|
||||
anthropicBaseURL = fmt.Sprintf("http://127.0.0.1:%d", apiPort)
|
||||
// Ollama client uses base URL directly
|
||||
ollamaBaseURL = fmt.Sprintf("http://127.0.0.1:%d", apiPort)
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
|
||||
Reference in New Issue
Block a user