mirror of
https://github.com/ollama/ollama.git
synced 2026-02-06 13:43:39 -05:00
Compare commits
1 Commits
main
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5a2849d1f |
@@ -6,6 +6,7 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
||||
@@ -194,20 +194,6 @@ func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[st
|
||||
return nil
|
||||
}
|
||||
|
||||
// showOrPull checks if a model exists via client.Show and offers to pull it if not found.
|
||||
func showOrPull(ctx context.Context, client *api.Client, model string) error {
|
||||
if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
|
||||
return nil
|
||||
}
|
||||
if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return pullModel(ctx, client, model)
|
||||
}
|
||||
|
||||
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -411,11 +397,8 @@ Examples:
|
||||
|
||||
// Validate --model flag if provided
|
||||
if modelFlag != "" {
|
||||
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
|
||||
return fmt.Errorf("model %q not found", modelFlag)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,11 +424,9 @@ Examples:
|
||||
|
||||
// Validate saved model still exists
|
||||
if model != "" && modelFlag == "" {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||
if err := showOrPull(cmd.Context(), client, model); err != nil {
|
||||
model = ""
|
||||
}
|
||||
model = ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -462,13 +443,6 @@ Examples:
|
||||
existingAliases = aliases
|
||||
}
|
||||
|
||||
// Ensure cloud models are authenticated
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Sync aliases and save
|
||||
if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
@@ -493,11 +467,8 @@ Examples:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := showOrPull(cmd.Context(), client, modelFlag); err != nil {
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if _, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelFlag}); err != nil {
|
||||
return fmt.Errorf("model %q not found", modelFlag)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -679,7 +650,7 @@ func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,17 +2,12 @@ package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -544,136 +539,3 @@ func TestAliasConfigurerInterface(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelExists(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"test-model"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := showOrPull(context.Background(), client, "test-model")
|
||||
if err != nil {
|
||||
t.Errorf("showOrPull should return nil when model exists, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, `{"error":"model not found"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
// confirmPrompt will fail in test (no terminal), so showOrPull should return an error
|
||||
err := showOrPull(context.Background(), client, "missing-model")
|
||||
if err == nil {
|
||||
t.Error("showOrPull should return error when model not found and no terminal available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
|
||||
var receivedModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
|
||||
receivedModel = req.Model
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
_ = showOrPull(context.Background(), client, "qwen3:8b")
|
||||
if receivedModel != "qwen3:8b" {
|
||||
t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_NoCloudModels(t *testing.T) {
|
||||
// ensureAuth should be a no-op when no cloud models are selected
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("no API calls expected when no cloud models selected")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
|
||||
// ensureAuth should only care about models in cloudModels map
|
||||
var whoamiCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/me" {
|
||||
whoamiCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"name":"testuser"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"cloud-model:cloud", "local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
|
||||
}
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for cloud model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
|
||||
var whoamiCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/me" {
|
||||
whoamiCalled = true
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
// cloudModels has entries but none are in selected
|
||||
cloudModels := map[string]bool{"cloud-model:cloud": true}
|
||||
selected := []string{"local-model"}
|
||||
|
||||
err := ensureAuth(context.Background(), client, cloudModels, selected)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error, got: %v", err)
|
||||
}
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called when no cloud models are selected")
|
||||
}
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -24,6 +24,7 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/klauspost/compress v1.18.3
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
||||
4
go.sum
4
go.sum
@@ -106,7 +106,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
|
||||
@@ -134,8 +133,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
|
||||
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
|
||||
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func zstdCompress(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
w, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestResponsesMiddlewareZstd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
useZstd bool
|
||||
oversized bool
|
||||
wantCode int
|
||||
wantModel string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "plain JSON",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd compressed",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd over max decompressed size",
|
||||
oversized: true,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
var bodyReader io.Reader
|
||||
if tt.oversized {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
|
||||
} else if tt.useZstd {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
|
||||
} else {
|
||||
bodyReader = strings.NewReader(tt.body)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.useZstd || tt.oversized {
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.wantCode {
|
||||
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("expected captured request, got nil")
|
||||
}
|
||||
if capturedRequest.Model != tt.wantModel {
|
||||
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
|
||||
}
|
||||
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
|
||||
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user