mirror of
https://github.com/ollama/ollama.git
synced 2026-02-05 21:23:43 -05:00
Compare commits
4 Commits
parth-laun
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f92a82db15 | ||
|
|
c61023f554 | ||
|
|
d25535c3f3 | ||
|
|
c323161f24 |
@@ -1,13 +1,12 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { Model } from "@/gotypes";
|
||||
import { getModels } from "@/api";
|
||||
import { mergeModels } from "@/utils/mergeModels";
|
||||
import { useSettings } from "./useSettings";
|
||||
import { useMemo } from "react";
|
||||
|
||||
const DEFAULT_MODEL = "gemma3:4b";
|
||||
|
||||
export function useModels(searchQuery = "") {
|
||||
const { settings } = useSettings();
|
||||
const localQuery = useQuery<Model[], Error>({
|
||||
const query = useQuery<Model[], Error>({
|
||||
queryKey: ["models", searchQuery],
|
||||
queryFn: () => getModels(searchQuery),
|
||||
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
|
||||
@@ -19,33 +18,18 @@ export function useModels(searchQuery = "") {
|
||||
refetchIntervalInBackground: true,
|
||||
});
|
||||
|
||||
const allModels = useMemo(() => {
|
||||
const models = mergeModels(localQuery.data || [], settings.airplaneMode);
|
||||
|
||||
if (searchQuery && searchQuery.trim()) {
|
||||
const query = searchQuery.toLowerCase().trim();
|
||||
const filteredModels = models.filter((model) =>
|
||||
model.model.toLowerCase().includes(query),
|
||||
);
|
||||
|
||||
const seen = new Set<string>();
|
||||
return filteredModels.filter((model) => {
|
||||
const currentModel = model.model.toLowerCase();
|
||||
if (seen.has(currentModel)) {
|
||||
return false;
|
||||
}
|
||||
seen.add(currentModel);
|
||||
return true;
|
||||
});
|
||||
const models = useMemo(() => {
|
||||
const data = query.data || [];
|
||||
if (data.length === 0) {
|
||||
return [new Model({ model: DEFAULT_MODEL })];
|
||||
}
|
||||
|
||||
return models;
|
||||
}, [localQuery.data, searchQuery, settings.airplaneMode]);
|
||||
return data;
|
||||
}, [query.data]);
|
||||
|
||||
return {
|
||||
...localQuery,
|
||||
data: allModels,
|
||||
isLoading: localQuery.isLoading,
|
||||
...query,
|
||||
data: models,
|
||||
isLoading: query.isLoading,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import { useModels } from "./useModels";
|
||||
import { useChat } from "./useChats";
|
||||
import { useSettings } from "./useSettings.ts";
|
||||
import { Model } from "@/gotypes";
|
||||
import { FEATURED_MODELS } from "@/utils/mergeModels";
|
||||
import { getTotalVRAM } from "@/utils/vram.ts";
|
||||
import { getInferenceCompute } from "@/api";
|
||||
|
||||
@@ -46,77 +45,13 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
const restoredChatRef = useRef<string | null>(null);
|
||||
|
||||
const selectedModel: Model | null = useMemo(() => {
|
||||
// if airplane mode is on and selected model ends with cloud,
|
||||
// switch to recommended default model
|
||||
if (settings.airplaneMode && settings.selectedModel?.endsWith("cloud")) {
|
||||
return (
|
||||
models.find((m) => m.model === recommendedModel) ||
|
||||
models.find((m) => m.isCloud) ||
|
||||
models.find((m) => m.digest === undefined || m.digest === "") ||
|
||||
models[0] ||
|
||||
null
|
||||
);
|
||||
}
|
||||
|
||||
// Migration logic: if turboEnabled is true and selectedModel is a base model,
|
||||
// migrate to the cloud version and disable turboEnabled permanently
|
||||
// TODO: remove this logic in a future release
|
||||
const baseModelsToMigrate = [
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
"deepseek-v3.1:671b",
|
||||
"qwen3-coder:480b",
|
||||
];
|
||||
const shouldMigrate =
|
||||
!settings.airplaneMode &&
|
||||
settings.turboEnabled &&
|
||||
baseModelsToMigrate.includes(settings.selectedModel);
|
||||
|
||||
if (shouldMigrate) {
|
||||
const cloudModel = `${settings.selectedModel}cloud`;
|
||||
return (
|
||||
models.find((m) => m.model === cloudModel) ||
|
||||
new Model({
|
||||
model: cloudModel,
|
||||
cloud: true,
|
||||
ollama_host: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
models.find((m) => m.model === settings.selectedModel) ||
|
||||
(settings.selectedModel &&
|
||||
new Model({
|
||||
model: settings.selectedModel,
|
||||
cloud: FEATURED_MODELS.some(
|
||||
(f) => f.endsWith("cloud") && f === settings.selectedModel,
|
||||
),
|
||||
ollama_host: false,
|
||||
})) ||
|
||||
new Model({ model: settings.selectedModel })) ||
|
||||
null
|
||||
);
|
||||
}, [models, settings.selectedModel, settings.airplaneMode, recommendedModel]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedModel) return;
|
||||
|
||||
if (
|
||||
settings.airplaneMode &&
|
||||
settings.selectedModel?.endsWith("cloud") &&
|
||||
selectedModel.model !== settings.selectedModel
|
||||
) {
|
||||
setSettings({ SelectedModel: selectedModel.model });
|
||||
}
|
||||
|
||||
if (
|
||||
!settings.airplaneMode &&
|
||||
settings.turboEnabled &&
|
||||
selectedModel.model !== settings.selectedModel
|
||||
) {
|
||||
setSettings({ SelectedModel: selectedModel.model, TurboEnabled: false });
|
||||
}
|
||||
}, [selectedModel, settings.airplaneMode, settings.selectedModel]);
|
||||
}, [models, settings.selectedModel]);
|
||||
|
||||
// Set model from chat history when chat data loads
|
||||
useEffect(() => {
|
||||
@@ -169,8 +104,6 @@ export function useSelectedModel(currentChatId?: string, searchQuery?: string) {
|
||||
|
||||
const defaultModel =
|
||||
models.find((m) => m.model === recommendedModel) ||
|
||||
models.find((m) => m.isCloud()) ||
|
||||
models.find((m) => m.digest === undefined || m.digest === "") ||
|
||||
models[0];
|
||||
|
||||
if (defaultModel) {
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { Model } from "@/gotypes";
|
||||
import { mergeModels, FEATURED_MODELS } from "@/utils/mergeModels";
|
||||
import "@/api";
|
||||
|
||||
describe("Model merging logic", () => {
|
||||
it("should handle cloud models with -cloud suffix", () => {
|
||||
const localModels: Model[] = [
|
||||
new Model({ model: "gpt-oss:120b-cloud" }),
|
||||
new Model({ model: "llama3:latest" }),
|
||||
new Model({ model: "mistral:latest" }),
|
||||
];
|
||||
|
||||
const merged = mergeModels(localModels);
|
||||
|
||||
// First verify cloud models are first and in FEATURED_MODELS order
|
||||
const cloudModels = FEATURED_MODELS.filter((m: string) =>
|
||||
m.endsWith("cloud"),
|
||||
);
|
||||
for (let i = 0; i < cloudModels.length; i++) {
|
||||
expect(merged[i].model).toBe(cloudModels[i]);
|
||||
expect(merged[i].isCloud()).toBe(true);
|
||||
}
|
||||
|
||||
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
|
||||
const nonCloudFeatured = FEATURED_MODELS.filter(
|
||||
(m: string) => !m.endsWith("cloud"),
|
||||
);
|
||||
for (let i = 0; i < nonCloudFeatured.length; i++) {
|
||||
const model = merged[i + cloudModels.length];
|
||||
expect(model.model).toBe(nonCloudFeatured[i]);
|
||||
expect(model.isCloud()).toBe(false);
|
||||
}
|
||||
|
||||
// Verify local models are preserved and come after featured models
|
||||
const featuredCount = FEATURED_MODELS.length;
|
||||
expect(merged[featuredCount].model).toBe("llama3:latest");
|
||||
expect(merged[featuredCount + 1].model).toBe("mistral:latest");
|
||||
|
||||
// Length should be exactly featured models plus our local models
|
||||
expect(merged.length).toBe(FEATURED_MODELS.length + 2);
|
||||
});
|
||||
|
||||
it("should hide cloud models in airplane mode", () => {
|
||||
const localModels: Model[] = [
|
||||
new Model({ model: "gpt-oss:120b-cloud" }),
|
||||
new Model({ model: "llama3:latest" }),
|
||||
new Model({ model: "mistral:latest" }),
|
||||
];
|
||||
|
||||
const merged = mergeModels(localModels, true); // airplane mode = true
|
||||
|
||||
// No cloud models should be present
|
||||
const cloudModels = merged.filter((m) => m.isCloud());
|
||||
expect(cloudModels.length).toBe(0);
|
||||
|
||||
// Should have non-cloud featured models
|
||||
const nonCloudFeatured = FEATURED_MODELS.filter(
|
||||
(m) => !m.endsWith("cloud"),
|
||||
);
|
||||
for (let i = 0; i < nonCloudFeatured.length; i++) {
|
||||
const model = merged[i];
|
||||
expect(model.model).toBe(nonCloudFeatured[i]);
|
||||
expect(model.isCloud()).toBe(false);
|
||||
}
|
||||
|
||||
// Local models should be preserved
|
||||
const featuredCount = nonCloudFeatured.length;
|
||||
expect(merged[featuredCount].model).toBe("llama3:latest");
|
||||
expect(merged[featuredCount + 1].model).toBe("mistral:latest");
|
||||
});
|
||||
|
||||
it("should handle empty input", () => {
|
||||
const merged = mergeModels([]);
|
||||
|
||||
// First verify cloud models are first and in FEATURED_MODELS order
|
||||
const cloudModels = FEATURED_MODELS.filter((m) => m.endsWith("cloud"));
|
||||
for (let i = 0; i < cloudModels.length; i++) {
|
||||
expect(merged[i].model).toBe(cloudModels[i]);
|
||||
expect(merged[i].isCloud()).toBe(true);
|
||||
}
|
||||
|
||||
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
|
||||
const nonCloudFeatured = FEATURED_MODELS.filter(
|
||||
(m) => !m.endsWith("cloud"),
|
||||
);
|
||||
for (let i = 0; i < nonCloudFeatured.length; i++) {
|
||||
const model = merged[i + cloudModels.length];
|
||||
expect(model.model).toBe(nonCloudFeatured[i]);
|
||||
expect(model.isCloud()).toBe(false);
|
||||
}
|
||||
|
||||
// Length should be exactly FEATURED_MODELS length
|
||||
expect(merged.length).toBe(FEATURED_MODELS.length);
|
||||
});
|
||||
|
||||
it("should sort models correctly", () => {
|
||||
const localModels: Model[] = [
|
||||
new Model({ model: "zephyr:latest" }),
|
||||
new Model({ model: "alpha:latest" }),
|
||||
new Model({ model: "gpt-oss:120b-cloud" }),
|
||||
];
|
||||
|
||||
const merged = mergeModels(localModels);
|
||||
|
||||
// First verify cloud models are first and in FEATURED_MODELS order
|
||||
const cloudModels = FEATURED_MODELS.filter((m) => m.endsWith("cloud"));
|
||||
for (let i = 0; i < cloudModels.length; i++) {
|
||||
expect(merged[i].model).toBe(cloudModels[i]);
|
||||
expect(merged[i].isCloud()).toBe(true);
|
||||
}
|
||||
|
||||
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
|
||||
const nonCloudFeatured = FEATURED_MODELS.filter(
|
||||
(m) => !m.endsWith("cloud"),
|
||||
);
|
||||
for (let i = 0; i < nonCloudFeatured.length; i++) {
|
||||
const model = merged[i + cloudModels.length];
|
||||
expect(model.model).toBe(nonCloudFeatured[i]);
|
||||
expect(model.isCloud()).toBe(false);
|
||||
}
|
||||
|
||||
// Non-featured local models should be at the end in alphabetical order
|
||||
const featuredCount = FEATURED_MODELS.length;
|
||||
expect(merged[featuredCount].model).toBe("alpha:latest");
|
||||
expect(merged[featuredCount + 1].model).toBe("zephyr:latest");
|
||||
});
|
||||
});
|
||||
@@ -1,101 +0,0 @@
|
||||
import { Model } from "@/gotypes";
|
||||
|
||||
// Featured models list (in priority order)
|
||||
export const FEATURED_MODELS = [
|
||||
"gpt-oss:120b-cloud",
|
||||
"gpt-oss:20b-cloud",
|
||||
"deepseek-v3.1:671b-cloud",
|
||||
"qwen3-coder:480b-cloud",
|
||||
"qwen3-vl:235b-cloud",
|
||||
"minimax-m2:cloud",
|
||||
"glm-4.6:cloud",
|
||||
"gpt-oss:120b",
|
||||
"gpt-oss:20b",
|
||||
"gemma3:27b",
|
||||
"gemma3:12b",
|
||||
"gemma3:4b",
|
||||
"gemma3:1b",
|
||||
"deepseek-r1:8b",
|
||||
"qwen3-coder:30b",
|
||||
"qwen3-vl:30b",
|
||||
"qwen3-vl:8b",
|
||||
"qwen3-vl:4b",
|
||||
"qwen3:30b",
|
||||
"qwen3:8b",
|
||||
"qwen3:4b",
|
||||
];
|
||||
|
||||
function alphabeticalSort(a: Model, b: Model): number {
|
||||
return a.model.toLowerCase().localeCompare(b.model.toLowerCase());
|
||||
}
|
||||
|
||||
//Merges models, sorting cloud models first, then other models
|
||||
export function mergeModels(
|
||||
localModels: Model[],
|
||||
airplaneMode: boolean = false,
|
||||
): Model[] {
|
||||
const allModels = (localModels || []).map((model) => model);
|
||||
|
||||
// 1. Get cloud models from local models and featured list
|
||||
const cloudModels = [...allModels.filter((m) => m.isCloud())];
|
||||
|
||||
// Add any cloud models from FEATURED_MODELS that aren't in local models
|
||||
FEATURED_MODELS.filter((f) => f.endsWith("cloud")).forEach((cloudModel) => {
|
||||
if (!cloudModels.some((m) => m.model === cloudModel)) {
|
||||
cloudModels.push(new Model({ model: cloudModel }));
|
||||
}
|
||||
});
|
||||
|
||||
// 2. Get other featured models (non-cloud)
|
||||
const featuredModels = FEATURED_MODELS.filter(
|
||||
(f) => !f.endsWith("cloud"),
|
||||
).map((model) => {
|
||||
// Check if this model exists in local models
|
||||
const localMatch = allModels.find(
|
||||
(m) => m.model.toLowerCase() === model.toLowerCase(),
|
||||
);
|
||||
|
||||
if (localMatch) return localMatch;
|
||||
|
||||
return new Model({
|
||||
model,
|
||||
});
|
||||
});
|
||||
|
||||
// 3. Get remaining local models that aren't featured and aren't cloud models
|
||||
const remainingModels = allModels.filter(
|
||||
(model) =>
|
||||
!model.isCloud() &&
|
||||
!FEATURED_MODELS.some(
|
||||
(f) => f.toLowerCase() === model.model.toLowerCase(),
|
||||
),
|
||||
);
|
||||
|
||||
cloudModels.sort((a, b) => {
|
||||
const aIndex = FEATURED_MODELS.indexOf(a.model);
|
||||
const bIndex = FEATURED_MODELS.indexOf(b.model);
|
||||
|
||||
// If both are featured, sort by their position in FEATURED_MODELS
|
||||
if (aIndex !== -1 && bIndex !== -1) {
|
||||
return aIndex - bIndex;
|
||||
}
|
||||
|
||||
// If only one is featured, featured model comes first
|
||||
if (aIndex !== -1 && bIndex === -1) return -1;
|
||||
if (aIndex === -1 && bIndex !== -1) return 1;
|
||||
|
||||
// If neither is featured, sort alphabetically
|
||||
return a.model.toLowerCase().localeCompare(b.model.toLowerCase());
|
||||
});
|
||||
|
||||
featuredModels.sort(
|
||||
(a, b) =>
|
||||
FEATURED_MODELS.indexOf(a.model) - FEATURED_MODELS.indexOf(b.model),
|
||||
);
|
||||
|
||||
remainingModels.sort(alphabeticalSort);
|
||||
|
||||
return airplaneMode
|
||||
? [...featuredModels, ...remainingModels]
|
||||
: [...cloudModels, ...featuredModels, ...remainingModels];
|
||||
}
|
||||
13
cmd/cmd.go
13
cmd/cmd.go
@@ -367,14 +367,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
} else if info.RemoteHost != "" {
|
||||
// Cloud model, no need to load/unload
|
||||
|
||||
isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
|
||||
|
||||
// Check if user is signed in for ollama.com cloud models
|
||||
if isCloud {
|
||||
if _, err := client.Whoami(cmd.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
|
||||
if isCloud {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
101
cmd/cmd_test.go
101
cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
t.Error("Copy Think should not be affected by original modification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteHost string
|
||||
whoamiStatus int
|
||||
whoamiResp any
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ollama.com cloud model - user signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
},
|
||||
{
|
||||
name: "ollama.com cloud model - user not signed in",
|
||||
remoteHost: "https://ollama.com",
|
||||
whoamiStatus: http.StatusUnauthorized,
|
||||
whoamiResp: map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://ollama.com/signin",
|
||||
},
|
||||
expectedError: "unauthorized",
|
||||
},
|
||||
{
|
||||
name: "non-ollama.com remote - no auth check",
|
||||
remoteHost: "https://other-remote.com",
|
||||
whoamiStatus: http.StatusUnauthorized, // should not be called
|
||||
whoamiResp: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
whoamiCalled := false
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.ShowResponse{
|
||||
RemoteHost: tt.remoteHost,
|
||||
RemoteModel: "test-model",
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
case "/api/me":
|
||||
whoamiCalled = true
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.whoamiStatus)
|
||||
if tt.whoamiResp != nil {
|
||||
if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
opts := &runOptions{
|
||||
Model: "test-cloud-model",
|
||||
ShowConnect: false,
|
||||
}
|
||||
|
||||
err := loadOrUnloadModel(cmd, opts)
|
||||
|
||||
if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
|
||||
if !whoamiCalled {
|
||||
t.Error("expected whoami to be called for ollama.com cloud model")
|
||||
}
|
||||
} else {
|
||||
if whoamiCalled {
|
||||
t.Error("whoami should not be called for non-ollama.com remote")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else {
|
||||
var authErr api.AuthorizationError
|
||||
if !errors.As(err, &authErr) {
|
||||
t.Errorf("expected AuthorizationError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
// TestNumPredict verifies that when num_predict is set, the model generates
|
||||
// exactly that many tokens. It uses logprobs to count the actual tokens output.
|
||||
func TestNumPredict(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
|
||||
t.Fatalf("failed to pull model: %v", err)
|
||||
}
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "qwen3:0.6b",
|
||||
Prompt: "Write a long story.",
|
||||
Stream: &stream,
|
||||
Logprobs: true,
|
||||
Options: map[string]any{
|
||||
"num_predict": 10,
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
|
||||
logprobCount := 0
|
||||
var finalResponse api.GenerateResponse
|
||||
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
|
||||
logprobCount += len(resp.Logprobs)
|
||||
if resp.Done {
|
||||
finalResponse = resp
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("generate failed: %v", err)
|
||||
}
|
||||
|
||||
if logprobCount != 10 {
|
||||
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
|
||||
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,6 +175,7 @@ type Tensor interface {
|
||||
SILU(ctx Context, up ...Tensor) Tensor
|
||||
RELU(ctx Context, up ...Tensor) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
SigmoidOut(ctx Context) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
@@ -1468,6 +1468,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
|
||||
@@ -135,7 +135,7 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
|
||||
// Apply shared expert gating
|
||||
if mlp.SharedGateInp != nil {
|
||||
sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
|
||||
sharedGateVal = sharedGateVal.Sigmoid(ctx)
|
||||
sharedGateVal = sharedGateVal.SigmoidOut(ctx)
|
||||
// Broadcast gate to match dimensions
|
||||
sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
|
||||
sharedOut = sharedOut.Mul(ctx, sharedGateVal)
|
||||
|
||||
@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
nextBatch.seqs[seqIdx] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []*input.Input{}
|
||||
@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numPredicted++
|
||||
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||
seq.inputs = []*input.Input{nextToken}
|
||||
nextBatchTokens[i] = nextToken
|
||||
@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
|
||||
seq.lastUpdatedAt = t
|
||||
seq.numPredicted++
|
||||
if seq.numPredicted == 1 {
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
seq.startedAt = seq.lastUpdatedAt
|
||||
@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
|
||||
Reference in New Issue
Block a user