mirror of
https://github.com/ollama/ollama.git
synced 2026-02-04 20:53:28 -05:00
Compare commits
3 Commits
v0.15.5-rc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c61023f554 | ||
|
|
d25535c3f3 | ||
|
|
c323161f24 |
13
cmd/cmd.go
13
cmd/cmd.go
@@ -367,14 +367,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
if _, err := client.Whoami(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
101
cmd/cmd_test.go
101
cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
t.Error("Copy Think should not be affected by original modification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteHost string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
},
|
||||
expectedError: "unauthorized",
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
remoteHost: "https://other-remote.com",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
whoamiCalled := false
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
case "/api/me":
|
||||
whoamiCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.whoamiStatus)
|
||||
if tt.whoamiResp != nil {
|
||||
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
} else {
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called for non-ollama.com remote")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
// TestNumPredict verifies that when num_predict is set, the model generates
|
||||
// exactly that many tokens. It uses logprobs to count the actual tokens output.
|
||||
func TestNumPredict(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
|
||||
t.Fatalf("failed to pull model: %v", err)
|
||||
}
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "qwen3:0.6b",
|
||||
Prompt: "Write a long story.",
|
||||
Stream: &stream,
|
||||
Logprobs: true,
|
||||
Options: map[string]any{
|
||||
"num_predict": 10,
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
|
||||
logprobCount := 0
|
||||
var finalResponse api.GenerateResponse
|
||||
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
|
||||
logprobCount += len(resp.Logprobs)
|
||||
if resp.Done {
|
||||
finalResponse = resp
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("generate failed: %v", err)
|
||||
}
|
||||
|
||||
if logprobCount != 10 {
|
||||
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
|
||||
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,6 +175,7 @@ type Tensor interface {
|
||||
SILU(ctx Context, up ...Tensor) Tensor
|
||||
RELU(ctx Context, up ...Tensor) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
SigmoidOut(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
@@ -1468,6 +1468,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
|
||||
@@ -135,7 +135,7 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
|
||||
// Apply shared expert gating
|
||||
if mlp.SharedGateInp != nil {
|
||||
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
|
||||
sharedGateVal = sharedGateVal.Sigmoid(ctx)
|
||||
sharedGateVal = sharedGateVal.SigmoidOut(ctx)
|
||||
// Broadcast gate to match dimensions
|
||||
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
|
||||
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
|
||||
|
||||
@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
nextBatch.seqs[seqIdx] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []*input.Input{}
|
||||
@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numPredicted++
|
||||
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||
seq.inputs = []*input.Input{nextToken}
|
||||
nextBatchTokens[i] = nextToken
|
||||
@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
|
||||
seq.lastUpdatedAt = t
|
||||
seq.numPredicted++
|
||||
if seq.numPredicted == 1 {
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
seq.startedAt = seq.lastUpdatedAt
|
||||
@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
|
||||
Reference in New Issue
Block a user