mirror of
https://github.com/ollama/ollama.git
synced 2026-01-17 03:49:12 -05:00
Compare commits
8 Commits
usage
...
mxyng/lint
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b59053a883 | ||
|
|
bb93e5afe7 | ||
|
|
4d24d8a77d | ||
|
|
f01c83ed6d | ||
|
|
d3228355be | ||
|
|
78a75a30d8 | ||
|
|
974ae8ef84 | ||
|
|
efd9f5e67e |
@@ -36,6 +36,12 @@ linters:
|
||||
errcheck:
|
||||
exclude-functions:
|
||||
- fmt.Fprintf
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
# Detects suspicious duplicated sub-expressions.
|
||||
# Prone to false positives when used on cgo code
|
||||
# https://github.com/go-critic/go-critic/issues/897#issuecomment-568892104
|
||||
- dupSubExpr
|
||||
perfsprint:
|
||||
strconcat: false
|
||||
concat-loop: false
|
||||
@@ -45,24 +51,22 @@ linters:
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- -SA1019
|
||||
# Incorrect or missing package comment.
|
||||
# https://staticcheck.dev/docs/checks/#ST1000
|
||||
- -ST1000
|
||||
# Poorly chosen identifier.
|
||||
# https://staticcheck.dev/docs/checks/#ST1003
|
||||
- -ST1003
|
||||
# The documentation of an exported function should start with the function's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1020
|
||||
- -ST1020
|
||||
# The documentation of an exported type should start with type's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1021
|
||||
- -ST1021
|
||||
# The documentation of an exported variable or constant should start with variable's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1022
|
||||
- -ST1022
|
||||
usestdlibvars:
|
||||
http-method: false
|
||||
http-status-code: false
|
||||
exclusions:
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
rules:
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- prealloc
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
|
||||
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -39,7 +40,7 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
|
||||
client, err := ClientFromEnvironment()
|
||||
if err != v.err {
|
||||
if !errors.Is(err, v.err) {
|
||||
t.Fatalf("expected %s, got %s", v.err, err)
|
||||
}
|
||||
|
||||
|
||||
25
api/types.go
25
api/types.go
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
@@ -308,9 +309,9 @@ func (tp ToolProperty) ToTypeScriptType() string {
|
||||
return mapToTypeScriptType(tp.Type[0])
|
||||
}
|
||||
|
||||
var types []string
|
||||
for _, t := range tp.Type {
|
||||
types = append(types, mapToTypeScriptType(t))
|
||||
types := make([]string, len(tp.Type))
|
||||
for i, t := range tp.Type {
|
||||
types[i] = mapToTypeScriptType(t)
|
||||
}
|
||||
return strings.Join(types, " | ")
|
||||
}
|
||||
@@ -783,7 +784,7 @@ func (m *Metrics) Summary() {
|
||||
|
||||
func (opts *Options) FromMap(m map[string]any) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||
typeOpts := reflect.TypeFor[Options]() // types of the fields in the options struct
|
||||
|
||||
// build map of json struct tags to their types
|
||||
jsonOpts := make(map[string]reflect.StructField)
|
||||
@@ -854,8 +855,7 @@ func (opts *Options) FromMap(m map[string]any) error {
|
||||
}
|
||||
field.Set(reflect.ValueOf(slice))
|
||||
case reflect.Pointer:
|
||||
var b bool
|
||||
if field.Type() == reflect.TypeOf(&b) {
|
||||
if field.Type() == reflect.TypeFor[*bool]() {
|
||||
val, ok := val.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type boolean", key)
|
||||
@@ -906,7 +906,7 @@ func DefaultOptions() Options {
|
||||
// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
|
||||
type ThinkValue struct {
|
||||
// Value can be a bool or string
|
||||
Value interface{}
|
||||
Value any
|
||||
}
|
||||
|
||||
// IsValid checks if the ThinkValue is valid
|
||||
@@ -999,7 +999,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
|
||||
return errors.New("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
@@ -1018,7 +1018,7 @@ func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
if d.Duration < 0 {
|
||||
return []byte("-1"), nil
|
||||
}
|
||||
return []byte("\"" + d.Duration.String() + "\""), nil
|
||||
return []byte("\"" + d.String() + "\""), nil
|
||||
}
|
||||
|
||||
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
@@ -1045,7 +1045,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
|
||||
return fmt.Errorf("unsupported type: '%s'", reflect.TypeOf(v))
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1055,7 +1055,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
func FormatParams(params map[string][]string) (map[string]any, error) {
|
||||
opts := Options{}
|
||||
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
||||
typeOpts := reflect.TypeFor[Options]() // types of the fields in the options struct
|
||||
|
||||
// build map of json struct tags to their types
|
||||
jsonOpts := make(map[string]reflect.StructField)
|
||||
@@ -1102,8 +1102,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
|
||||
// TODO: only string slices are supported right now
|
||||
out[key] = vals
|
||||
case reflect.Pointer:
|
||||
var b bool
|
||||
if field.Type() == reflect.TypeOf(&b) {
|
||||
if field.Type() == reflect.TypeFor[*bool]() {
|
||||
boolVal, err := strconv.ParseBool(vals[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid bool value %s", vals)
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
var ErrCancelled = errors.New("Cancelled")
|
||||
|
||||
// Cancelled refers to ErrCancelled.
|
||||
//
|
||||
// Deprecated: Use ErrCancelled instead.
|
||||
var Cancelled = ErrCancelled
|
||||
|
||||
@@ -37,7 +38,7 @@ type MsgBuilder struct {
|
||||
}
|
||||
|
||||
// Message initialises a MsgBuilder with the provided message.
|
||||
func Message(format string, args ...interface{}) *MsgBuilder {
|
||||
func Message(format string, args ...any) *MsgBuilder {
|
||||
return &MsgBuilder{Msg: fmt.Sprintf(format, args...)}
|
||||
}
|
||||
|
||||
|
||||
@@ -319,7 +319,7 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout scanning server log for inference compute details")
|
||||
return nil, errors.New("timeout scanning server log for inference compute details")
|
||||
default:
|
||||
}
|
||||
file, err := os.Open(serverLogPath)
|
||||
@@ -345,11 +345,9 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
|
||||
|
||||
slog.Info("Matched", "inference compute", ic)
|
||||
inference = append(inference, ic)
|
||||
} else {
|
||||
} else if len(inference) > 0 {
|
||||
// Break out on first non matching line after we start matching
|
||||
if len(inference) > 0 {
|
||||
return inference, nil
|
||||
}
|
||||
return inference, nil
|
||||
}
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -31,7 +31,7 @@ func terminate(proc *os.Process) error {
|
||||
func terminated(pid int) (bool, error) {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to find process: %v", err)
|
||||
return false, fmt.Errorf("failed to find process: %w", err)
|
||||
}
|
||||
|
||||
err = proc.Signal(syscall.Signal(0))
|
||||
@@ -40,7 +40,7 @@ func terminated(pid int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("error signaling process: %v", err)
|
||||
return false, fmt.Errorf("error signaling process: %w", err)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
@@ -67,8 +67,7 @@ func reapServers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
pids := strings.Split(pidsStr, "\n")
|
||||
for _, pidStr := range pids {
|
||||
for pidStr := range strings.SplitSeq(pidsStr, "\n") {
|
||||
pidStr = strings.TrimSpace(pidStr)
|
||||
if pidStr == "" {
|
||||
continue
|
||||
|
||||
@@ -5,6 +5,7 @@ package store
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -482,7 +483,8 @@ func (db *database) cleanupOrphanedData() error {
|
||||
}
|
||||
|
||||
func duplicateColumnError(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
var sqlite3Err sqlite3.Error
|
||||
if errors.As(err, &sqlite3Err) {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "duplicate column name")
|
||||
}
|
||||
@@ -490,7 +492,8 @@ func duplicateColumnError(err error) bool {
|
||||
}
|
||||
|
||||
func columnNotExists(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
var sqlite3Err sqlite3.Error
|
||||
if errors.As(err, &sqlite3Err) {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "no such column")
|
||||
}
|
||||
@@ -586,8 +589,8 @@ func (db *database) getChatWithOptions(id string, loadAttachmentData bool) (*Cha
|
||||
&browserState,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("chat not found")
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, errors.New("chat not found")
|
||||
}
|
||||
return nil, fmt.Errorf("query chat: %w", err)
|
||||
}
|
||||
@@ -752,7 +755,7 @@ func (db *database) updateLastMessage(chatID string, msg Message) error {
|
||||
return fmt.Errorf("get rows affected: %w", err)
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("no message found to update")
|
||||
return errors.New("no message found to update")
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM attachments WHERE message_id = ?", messageID)
|
||||
|
||||
@@ -282,7 +282,7 @@ func countRows(t *testing.T, db *database, table string) int {
|
||||
return count
|
||||
}
|
||||
|
||||
func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...interface{}) int {
|
||||
func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...any) int {
|
||||
t.Helper()
|
||||
var count int
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s", table, condition)
|
||||
@@ -296,7 +296,7 @@ func countRowsWithCondition(t *testing.T, db *database, table, condition string,
|
||||
// Test helpers for schema migration testing
|
||||
|
||||
// schemaMap returns both tables/columns and indexes (ignoring order)
|
||||
func schemaMap(db *database) map[string]interface{} {
|
||||
func schemaMap(db *database) map[string]any {
|
||||
result := make(map[string]any)
|
||||
|
||||
result["tables"] = columnMap(db)
|
||||
|
||||
@@ -5,6 +5,7 @@ package store
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -26,7 +27,7 @@ func (i *Image) Bytes() ([]byte, error) {
|
||||
// ImgBytes reads image data from the specified file path
|
||||
func ImgBytes(path string) ([]byte, error) {
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("empty image path")
|
||||
return nil, errors.New("empty image path")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
|
||||
@@ -4,6 +4,7 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -130,7 +131,7 @@ func (b *BrowserSearch) Schema() map[string]any {
|
||||
func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("query parameter is required")
|
||||
return nil, "", errors.New("query parameter is required")
|
||||
}
|
||||
|
||||
topn, ok := args["topn"].(int)
|
||||
@@ -150,7 +151,7 @@ func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any,
|
||||
|
||||
searchResponse, ok := result.(*WebSearchResponse)
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("invalid search results format")
|
||||
return nil, "", errors.New("invalid search results format")
|
||||
}
|
||||
|
||||
// Build main search results page that contains all search results
|
||||
@@ -383,15 +384,9 @@ func wrapLines(text string, width int) []string {
|
||||
wrapped = append(wrapped, "")
|
||||
} else if len(line) <= width {
|
||||
wrapped = append(wrapped, line)
|
||||
} else if words := strings.Fields(line); len(words) == 0 {
|
||||
wrapped = append(wrapped, line)
|
||||
} else {
|
||||
// Word wrapping while preserving whitespace structure
|
||||
words := strings.Fields(line)
|
||||
if len(words) == 0 {
|
||||
// Line with only whitespace
|
||||
wrapped = append(wrapped, line)
|
||||
continue
|
||||
}
|
||||
|
||||
currentLine := ""
|
||||
for _, word := range words {
|
||||
// Check if adding this word would exceed width
|
||||
@@ -536,15 +531,13 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
|
||||
}
|
||||
} else {
|
||||
} else if len(b.state.Data.PageStack) != 0 {
|
||||
// get last page
|
||||
if len(b.state.Data.PageStack) != 0 {
|
||||
pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1]
|
||||
var err error
|
||||
page, err = b.getPageFromStack(pageURL)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
|
||||
}
|
||||
pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1]
|
||||
var err error
|
||||
page, err = b.getPageFromStack(pageURL)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -594,7 +587,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
|
||||
// Try to get id as integer (link ID from current page)
|
||||
if id, ok := args["id"].(float64); ok {
|
||||
if page == nil {
|
||||
return nil, "", fmt.Errorf("no current page to resolve link from")
|
||||
return nil, "", errors.New("no current page to resolve link from")
|
||||
}
|
||||
idInt := int(id)
|
||||
pageURL, ok := page.Links[idInt]
|
||||
@@ -637,7 +630,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
|
||||
|
||||
// If no id provided, just display current page
|
||||
if page == nil {
|
||||
return nil, "", fmt.Errorf("no current page to display")
|
||||
return nil, "", errors.New("no current page to display")
|
||||
}
|
||||
// Only add to PageStack without updating URLToPage
|
||||
b.state.Data.PageStack = append(b.state.Data.PageStack, page.URL)
|
||||
@@ -742,7 +735,7 @@ func (b *BrowserFind) Schema() map[string]any {
|
||||
func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, string, error) {
|
||||
pattern, ok := args["pattern"].(string)
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("pattern parameter is required")
|
||||
return nil, "", errors.New("pattern parameter is required")
|
||||
}
|
||||
|
||||
// Get cursor parameter if provided, default to current page
|
||||
@@ -756,7 +749,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
|
||||
if cursor == -1 {
|
||||
// Use current page
|
||||
if len(b.state.Data.PageStack) == 0 {
|
||||
return nil, "", fmt.Errorf("no pages to search in")
|
||||
return nil, "", errors.New("no pages to search in")
|
||||
}
|
||||
var err error
|
||||
page, err = b.getPageFromStack(b.state.Data.PageStack[len(b.state.Data.PageStack)-1])
|
||||
@@ -776,7 +769,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
|
||||
}
|
||||
|
||||
if page == nil {
|
||||
return nil, "", fmt.Errorf("page not found")
|
||||
return nil, "", errors.New("page not found")
|
||||
}
|
||||
|
||||
// Create find results page
|
||||
|
||||
@@ -5,6 +5,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
@@ -87,7 +88,7 @@ func (g *BrowserCrawler) Schema() map[string]any {
|
||||
func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*CrawlResponse, error) {
|
||||
urlsRaw, ok := args["urls"].([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("urls parameter is required and must be an array of strings")
|
||||
return nil, errors.New("urls parameter is required and must be an array of strings")
|
||||
}
|
||||
|
||||
urls := make([]string, 0, len(urlsRaw))
|
||||
@@ -98,7 +99,7 @@ func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*Cra
|
||||
}
|
||||
|
||||
if len(urls) == 0 {
|
||||
return nil, fmt.Errorf("at least one URL is required")
|
||||
return nil, errors.New("at least one URL is required")
|
||||
}
|
||||
|
||||
return g.performWebCrawl(ctx, urls)
|
||||
|
||||
@@ -5,6 +5,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -84,7 +85,7 @@ func (w *BrowserWebSearch) Schema() map[string]any {
|
||||
func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (any, error) {
|
||||
queriesRaw, ok := args["queries"].([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("queries parameter is required and must be an array of strings")
|
||||
return nil, errors.New("queries parameter is required and must be an array of strings")
|
||||
}
|
||||
|
||||
queries := make([]string, 0, len(queriesRaw))
|
||||
@@ -95,7 +96,7 @@ func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (an
|
||||
}
|
||||
|
||||
if len(queries) == 0 {
|
||||
return nil, fmt.Errorf("at least one query is required")
|
||||
return nil, errors.New("at least one query is required")
|
||||
}
|
||||
|
||||
maxResults := 5
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -36,7 +37,7 @@ func (w *WebFetch) Description() string {
|
||||
return "Crawl and extract text content from web pages"
|
||||
}
|
||||
|
||||
func (g *WebFetch) Schema() map[string]any {
|
||||
func (w *WebFetch) Schema() map[string]any {
|
||||
schemaBytes := []byte(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -61,11 +62,11 @@ func (w *WebFetch) Prompt() string {
|
||||
func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
|
||||
urlRaw, ok := args["url"]
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("url parameter is required")
|
||||
return nil, "", errors.New("url parameter is required")
|
||||
}
|
||||
urlStr, ok := urlRaw.(string)
|
||||
if !ok || strings.TrimSpace(urlStr) == "" {
|
||||
return nil, "", fmt.Errorf("url must be a non-empty string")
|
||||
return nil, "", errors.New("url must be a non-empty string")
|
||||
}
|
||||
|
||||
result, err := performWebFetch(ctx, urlStr)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -45,7 +46,7 @@ func (w *WebSearch) Prompt() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (g *WebSearch) Schema() map[string]any {
|
||||
func (w *WebSearch) Schema() map[string]any {
|
||||
schemaBytes := []byte(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -71,12 +72,12 @@ func (g *WebSearch) Schema() map[string]any {
|
||||
func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
|
||||
rawQuery, ok := args["query"]
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("query parameter is required")
|
||||
return nil, "", errors.New("query parameter is required")
|
||||
}
|
||||
|
||||
queryStr, ok := rawQuery.(string)
|
||||
if !ok || strings.TrimSpace(queryStr) == "" {
|
||||
return nil, "", fmt.Errorf("query must be a non-empty string")
|
||||
return nil, "", errors.New("query must be a non-empty string")
|
||||
}
|
||||
|
||||
maxResults := 5
|
||||
|
||||
@@ -19,10 +19,12 @@ import (
|
||||
// Errors wrapping Found should provide additional context, e.g.
|
||||
// fmt.Errorf("%w: %s", not.Found, key)
|
||||
//
|
||||
//nolint:staticcheck
|
||||
//lint:ignore ST1012 This is a sentinel error intended to be read like not.Found.
|
||||
var Found = errors.New("not found")
|
||||
|
||||
// Available is an error that indicates that a value is not available.
|
||||
//
|
||||
//nolint:staticcheck
|
||||
//lint:ignore ST1012 This is a sentinel error intended to be read like not.Available.
|
||||
var Available = errors.New("not available")
|
||||
|
||||
@@ -4,6 +4,7 @@ package not
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ValidError struct {
|
||||
@@ -44,12 +45,12 @@ func (b Valids) Error() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result string
|
||||
var sb strings.Builder
|
||||
for i, err := range b {
|
||||
if i > 0 {
|
||||
result += "; "
|
||||
sb.WriteString("; ")
|
||||
}
|
||||
result += err.Error()
|
||||
sb.WriteString(err.Error())
|
||||
}
|
||||
return result
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -73,7 +74,7 @@ func extractPDFText(data []byte) (string, error) {
|
||||
if strings.TrimSpace(text) != "" {
|
||||
if textBuilder.Len() > 0 {
|
||||
textBuilder.WriteString("\n\n--- Page ")
|
||||
textBuilder.WriteString(fmt.Sprintf("%d", i))
|
||||
textBuilder.WriteString(strconv.Itoa(i))
|
||||
textBuilder.WriteString(" ---\n")
|
||||
}
|
||||
textBuilder.WriteString(text)
|
||||
|
||||
32
app/ui/ui.go
32
app/ui/ui.go
@@ -194,7 +194,7 @@ func (s *Server) Handler() http.Handler {
|
||||
log := s.log()
|
||||
level := slog.LevelInfo
|
||||
start := time.Now()
|
||||
requestID := fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
requestID := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
|
||||
defer func() {
|
||||
p := recover()
|
||||
@@ -204,7 +204,7 @@ func (s *Server) Handler() http.Handler {
|
||||
|
||||
// Handle panic with user-friendly error
|
||||
if !sw.Written() {
|
||||
s.handleError(sw, fmt.Errorf("internal server error"))
|
||||
s.handleError(sw, errors.New("internal server error"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -382,7 +382,7 @@ func waitForServer(ctx context.Context) error {
|
||||
break
|
||||
}
|
||||
if time.Now().After(timeout) {
|
||||
return fmt.Errorf("timeout waiting for Ollama server to be ready")
|
||||
return errors.New("timeout waiting for Ollama server to be ready")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
@@ -455,7 +455,7 @@ func (s *Server) checkModelUpstream(ctx context.Context, modelName string, timeo
|
||||
|
||||
digest := resp.Header.Get("ollama-content-digest")
|
||||
if digest == "" {
|
||||
return "", 0, fmt.Errorf("no digest header found")
|
||||
return "", 0, errors.New("no digest header found")
|
||||
}
|
||||
|
||||
var pushTime int64
|
||||
@@ -598,12 +598,12 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
return fmt.Errorf("empty model")
|
||||
return errors.New("empty model")
|
||||
}
|
||||
|
||||
// Don't allow empty messages unless forceUpdate is true
|
||||
if req.Prompt == "" && !req.ForceUpdate {
|
||||
return fmt.Errorf("empty message")
|
||||
return errors.New("empty message")
|
||||
}
|
||||
|
||||
if createdChat {
|
||||
@@ -942,7 +942,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
||||
} else {
|
||||
onlyStandalone := true
|
||||
for _, tc := range res.Message.ToolCalls {
|
||||
if !(tc.Function.Name == "web_search" || tc.Function.Name == "web_fetch") {
|
||||
if tc.Function.Name != "web_search" && tc.Function.Name != "web_fetch" {
|
||||
onlyStandalone = false
|
||||
break
|
||||
}
|
||||
@@ -1194,7 +1194,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
|
||||
cid := r.PathValue("id")
|
||||
|
||||
if cid == "" {
|
||||
return fmt.Errorf("chat ID is required")
|
||||
return errors.New("chat ID is required")
|
||||
}
|
||||
|
||||
chat, err := s.Store.Chat(cid)
|
||||
@@ -1252,7 +1252,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
|
||||
func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
|
||||
cid := r.PathValue("id")
|
||||
if cid == "" {
|
||||
return fmt.Errorf("chat ID is required")
|
||||
return errors.New("chat ID is required")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
@@ -1283,7 +1283,7 @@ func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
|
||||
func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
|
||||
cid := r.PathValue("id")
|
||||
if cid == "" {
|
||||
return fmt.Errorf("chat ID is required")
|
||||
return errors.New("chat ID is required")
|
||||
}
|
||||
|
||||
// Check if the chat exists (no need to load attachments)
|
||||
@@ -1291,7 +1291,7 @@ func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
|
||||
if err != nil {
|
||||
if errors.Is(err, not.Found) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return fmt.Errorf("chat not found")
|
||||
return errors.New("chat not found")
|
||||
}
|
||||
return fmt.Errorf("failed to get chat: %w", err)
|
||||
}
|
||||
@@ -1592,7 +1592,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
|
||||
|
||||
func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return fmt.Errorf("method not allowed")
|
||||
return errors.New("method not allowed")
|
||||
}
|
||||
|
||||
var req struct {
|
||||
@@ -1603,7 +1603,7 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
return fmt.Errorf("model is required")
|
||||
return errors.New("model is required")
|
||||
}
|
||||
|
||||
digest, pushTime, err := s.checkModelUpstream(r.Context(), req.Model, 5*time.Second)
|
||||
@@ -1730,8 +1730,8 @@ func supportsWebSearchTools(model string) bool {
|
||||
|
||||
// buildChatRequest converts store.Chat to api.ChatRequest
|
||||
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
|
||||
var msgs []api.Message
|
||||
for _, m := range chat.Messages {
|
||||
msgs := make([]api.Message, len(chat.Messages))
|
||||
for i, m := range chat.Messages {
|
||||
// Skip empty messages if present
|
||||
if m.Content == "" && m.Thinking == "" && len(m.ToolCalls) == 0 && len(m.Attachments) == 0 {
|
||||
continue
|
||||
@@ -1789,7 +1789,7 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
|
||||
s.log().Debug("unknown message role", "role", m.Role)
|
||||
}
|
||||
|
||||
msgs = append(msgs, apiMsg)
|
||||
msgs[i] = apiMsg
|
||||
}
|
||||
|
||||
var thinkValue *api.ThinkValue
|
||||
|
||||
@@ -198,7 +198,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
||||
_, err = os.Stat(filepath.Dir(stageFilename))
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.MkdirAll(filepath.Dir(stageFilename), 0o755); err != nil {
|
||||
return fmt.Errorf("create ollama dir %s: %v", filepath.Dir(stageFilename), err)
|
||||
return fmt.Errorf("create ollama dir %s: %w", filepath.Dir(stageFilename), err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,7 +218,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
||||
|
||||
if err := VerifyDownload(); err != nil {
|
||||
_ = os.Remove(stageFilename)
|
||||
return fmt.Errorf("%s - %s", resp.Request.URL.String(), err)
|
||||
return fmt.Errorf("%s - %w", resp.Request.URL.String(), err)
|
||||
}
|
||||
UpdateDownloaded = true
|
||||
return nil
|
||||
|
||||
@@ -92,7 +92,7 @@ func DoUpgrade(interactive bool) error {
|
||||
|
||||
bundle := getStagedUpdate()
|
||||
if bundle == "" {
|
||||
return fmt.Errorf("failed to lookup downloads")
|
||||
return errors.New("failed to lookup downloads")
|
||||
}
|
||||
|
||||
slog.Info("starting upgrade", "app", BundlePath, "update", bundle, "pid", os.Getpid(), "log", UpgradeLogFile)
|
||||
@@ -107,7 +107,7 @@ func DoUpgrade(interactive bool) error {
|
||||
// Verify old doesn't exist yet
|
||||
if _, err := os.Stat(contentsOldName); err == nil {
|
||||
slog.Error("prior upgrade failed", "backup", contentsOldName)
|
||||
return fmt.Errorf("prior upgrade failed - please upgrade manually by installing the bundle")
|
||||
return errors.New("prior upgrade failed - please upgrade manually by installing the bundle")
|
||||
}
|
||||
if err := os.MkdirAll(appBackupDir, 0o755); err != nil {
|
||||
return fmt.Errorf("unable to create backup dir %s: %w", appBackupDir, err)
|
||||
@@ -133,7 +133,7 @@ func DoUpgrade(interactive bool) error {
|
||||
return err
|
||||
}
|
||||
if !chownWithAuthorization(u.Username) {
|
||||
return fmt.Errorf("unable to change permissions to complete upgrade")
|
||||
return errors.New("unable to change permissions to complete upgrade")
|
||||
}
|
||||
if err := os.Rename(BundlePath, appBackup); err != nil {
|
||||
return fmt.Errorf("unable to perform upgrade - failed to stage old version: %w", err)
|
||||
@@ -264,7 +264,7 @@ func DoPostUpgradeCleanup() error {
|
||||
func verifyDownload() error {
|
||||
bundle := getStagedUpdate()
|
||||
if bundle == "" {
|
||||
return fmt.Errorf("failed to lookup downloads")
|
||||
return errors.New("failed to lookup downloads")
|
||||
}
|
||||
slog.Debug("verifying update", "bundle", bundle)
|
||||
|
||||
@@ -338,7 +338,7 @@ func verifyDownload() error {
|
||||
}
|
||||
|
||||
if err := verifyExtractedBundle(filepath.Join(dir, "Ollama.app")); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %s", err)
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -347,11 +347,11 @@ func verifyDownload() error {
|
||||
func DoUpgradeAtStartup() error {
|
||||
bundle := getStagedUpdate()
|
||||
if bundle == "" {
|
||||
return fmt.Errorf("failed to lookup downloads")
|
||||
return errors.New("failed to lookup downloads")
|
||||
}
|
||||
|
||||
if BundlePath == "" {
|
||||
return fmt.Errorf("unable to upgrade at startup, app in development mode")
|
||||
return errors.New("unable to upgrade at startup, app in development mode")
|
||||
}
|
||||
|
||||
// [Re]verify before proceeding
|
||||
|
||||
@@ -22,9 +22,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
|
||||
// TODO - wire up the redirects to mimic real behavior
|
||||
} else {
|
||||
slog.Debug("unexpected request", "url", r.URL)
|
||||
@@ -67,17 +65,16 @@ func TestBackgoundChecker(t *testing.T) {
|
||||
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/update.json" {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
switch r.URL.Path {
|
||||
case "/update.json":
|
||||
fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
|
||||
// TODO - wire up the redirects to mimic real behavior
|
||||
} else if r.URL.Path == "/9.9.9/"+Installer {
|
||||
case "/9.9.9/" + Installer:
|
||||
buf := &bytes.Buffer{}
|
||||
zw := zip.NewWriter(buf)
|
||||
zw.Close()
|
||||
io.Copy(w, buf)
|
||||
} else {
|
||||
default:
|
||||
slog.Debug("unexpected request", "url", r.URL)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -149,7 +149,7 @@ func BenchmarkChat(fOpt flagOptions) error {
|
||||
|
||||
for _, model := range models {
|
||||
for range *fOpt.epochs {
|
||||
options := make(map[string]interface{})
|
||||
options := make(map[string]any)
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
|
||||
@@ -442,7 +442,7 @@ func TestReadImage_FileNotFound(t *testing.T) {
|
||||
func TestOptionsMapCreation(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
|
||||
options := make(map[string]interface{})
|
||||
options := make(map[string]any)
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
|
||||
35
cmd/cmd.go
35
cmd/cmd.go
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"maps"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -203,7 +204,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if err := client.Create(cmd.Context(), req, fn); err != nil {
|
||||
if strings.Contains(err.Error(), "path or Modelfile are required") {
|
||||
return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client")
|
||||
return errors.New("the ollama server must be updated to use `ollama create` with this client")
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -990,7 +991,7 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
var v string
|
||||
switch vData := resp.ModelInfo[k].(type) {
|
||||
case bool:
|
||||
v = fmt.Sprintf("%t", vData)
|
||||
v = strconv.FormatBool(vData)
|
||||
case string:
|
||||
v = vData
|
||||
case float64:
|
||||
@@ -1204,9 +1205,7 @@ func (r runOptions) Copy() runOptions {
|
||||
var opts map[string]any
|
||||
if r.Options != nil {
|
||||
opts = make(map[string]any, len(r.Options))
|
||||
for k, v := range r.Options {
|
||||
opts[k] = v
|
||||
}
|
||||
maps.Copy(opts, r.Options)
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
@@ -1330,12 +1329,12 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
state := &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
thinkTagOpened := false
|
||||
thinkTagClosed := false
|
||||
|
||||
role := "assistant"
|
||||
|
||||
@@ -1463,10 +1462,10 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
state := &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
thinkTagOpened := false
|
||||
thinkTagClosed := false
|
||||
|
||||
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
|
||||
|
||||
@@ -1634,7 +1633,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
return err
|
||||
}
|
||||
if err := client.Heartbeat(cmd.Context()); err != nil {
|
||||
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
|
||||
if !strings.Contains(err.Error(), " refused") && !strings.Contains(err.Error(), "could not connect") {
|
||||
return err
|
||||
}
|
||||
if err := startApp(cmd.Context(), client); err != nil {
|
||||
@@ -1952,13 +1951,13 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit
|
||||
}
|
||||
|
||||
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||
out := ""
|
||||
var sb strings.Builder
|
||||
formatExplanation := ""
|
||||
formatValues := ""
|
||||
if !plainText {
|
||||
formatExplanation = readline.ColorGrey + readline.ColorBold
|
||||
formatValues = readline.ColorDefault
|
||||
out += formatExplanation
|
||||
sb.WriteString(formatExplanation)
|
||||
}
|
||||
for i, toolCall := range toolCalls {
|
||||
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
|
||||
@@ -1966,13 +1965,13 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||
return ""
|
||||
}
|
||||
if i > 0 {
|
||||
out += "\n"
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
// all tool calls are unexpected since we don't currently support registering any in the CLI
|
||||
out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
|
||||
fmt.Fprintf(&sb, " Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
|
||||
}
|
||||
if !plainText {
|
||||
out += readline.ColorDefault
|
||||
sb.WriteString(readline.ColorDefault)
|
||||
}
|
||||
return out
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -307,7 +308,7 @@ func TestDeleteHandler(t *testing.T) {
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
errPayload := `{"error":"model '%s' not found"}`
|
||||
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
|
||||
fmt.Fprintf(w, errPayload, req.Name)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -761,8 +762,8 @@ func TestGetModelfileName(t *testing.T) {
|
||||
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
|
||||
}
|
||||
|
||||
if tt.expectedErr != os.ErrNotExist {
|
||||
if actualErr != tt.expectedErr {
|
||||
if !errors.Is(tt.expectedErr, os.ErrNotExist) {
|
||||
if !errors.Is(actualErr, tt.expectedErr) {
|
||||
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
|
||||
}
|
||||
} else {
|
||||
@@ -924,10 +925,8 @@ func TestPushHandler(t *testing.T) {
|
||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
} else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1014,10 +1013,8 @@ func TestListHandler(t *testing.T) {
|
||||
if got := string(output); got != tt.expectedOutput {
|
||||
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
|
||||
}
|
||||
} else {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
} else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1322,8 +1319,8 @@ func TestRunOptions_Copy(t *testing.T) {
|
||||
// Test 2: Verify all fields are copied correctly
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
got any
|
||||
want any
|
||||
}{
|
||||
{"Model", copied.Model, original.Model},
|
||||
{"ParentModel", copied.ParentModel, original.ParentModel},
|
||||
|
||||
@@ -130,7 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
var sb strings.Builder
|
||||
var multiline MultilineState
|
||||
var thinkExplicitlySet bool = opts.Think != nil
|
||||
thinkExplicitlySet := opts.Think != nil
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -410,7 +410,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println(" No additional parameters were specified for this model.")
|
||||
} else {
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
for l := range strings.SplitSeq(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
}
|
||||
}
|
||||
@@ -576,9 +576,8 @@ func extractFileNames(input string) []string {
|
||||
|
||||
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
filePaths := extractFileNames(input)
|
||||
var imgs []api.ImageData
|
||||
|
||||
for _, fp := range filePaths {
|
||||
imgs := make([]api.ImageData, len(filePaths))
|
||||
for i, fp := range filePaths {
|
||||
nfp := normalizeFilePath(fp)
|
||||
data, err := getImageData(nfp)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
@@ -591,7 +590,7 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
imgs[i] = data
|
||||
}
|
||||
return strings.TrimSpace(input), imgs, nil
|
||||
}
|
||||
|
||||
@@ -38,10 +38,10 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
"general.file_type": uint32(1),
|
||||
"general.quantization_version": uint32(2),
|
||||
"tokenizer.ggml.pre": t.Pre,
|
||||
"tokenizer.ggml.model": t.Vocabulary.Model,
|
||||
"tokenizer.ggml.tokens": t.Vocabulary.Tokens,
|
||||
"tokenizer.ggml.scores": t.Vocabulary.Scores,
|
||||
"tokenizer.ggml.token_type": t.Vocabulary.Types,
|
||||
"tokenizer.ggml.model": t.Model,
|
||||
"tokenizer.ggml.tokens": t.Tokens,
|
||||
"tokenizer.ggml.scores": t.Scores,
|
||||
"tokenizer.ggml.token_type": t.Types,
|
||||
}
|
||||
|
||||
if len(t.Merges) > 0 {
|
||||
@@ -231,20 +231,20 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
|
||||
switch {
|
||||
case vocabSize == 0:
|
||||
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||
case vocabSize > len(t.Vocabulary.Tokens):
|
||||
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
||||
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
|
||||
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
|
||||
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
|
||||
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Tokens))
|
||||
case vocabSize > len(t.Tokens):
|
||||
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Tokens))
|
||||
for i := range vocabSize - len(t.Tokens) {
|
||||
t.Tokens = append(t.Tokens, fmt.Sprintf("[PAD%d]", i))
|
||||
t.Scores = append(t.Scores, -1)
|
||||
t.Types = append(t.Types, tokenTypeUserDefined)
|
||||
}
|
||||
case vocabSize < len(t.Vocabulary.Tokens):
|
||||
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
|
||||
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||
case vocabSize < len(t.Tokens):
|
||||
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Tokens))
|
||||
p.VocabSize = uint32(len(t.Tokens))
|
||||
p.TextModel.VocabSize = uint32(len(t.Tokens))
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
slog.Debug("vocabulary", "size", len(t.Tokens))
|
||||
}
|
||||
|
||||
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
||||
|
||||
@@ -137,7 +137,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
if slices.Contains([]string{
|
||||
"embeddings.position_ids",
|
||||
|
||||
@@ -44,14 +44,14 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
out := make([]*ggml.Tensor, len(ts))
|
||||
for i, t := range ts {
|
||||
out[i] = &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -43,18 +43,18 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out := make([]*ggml.Tensor, len(ts))
|
||||
for i, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
|
||||
t.SetRepacker(p.addOne)
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out[i] = &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -22,8 +22,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out := make([]*ggml.Tensor, len(ts))
|
||||
for i, t := range ts {
|
||||
shape := t.Shape()
|
||||
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
|
||||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
|
||||
@@ -31,12 +31,12 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out[i] = &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -111,7 +111,7 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
if !strings.HasSuffix(name, ".weight") {
|
||||
name = name + ".weight"
|
||||
name += ".weight"
|
||||
}
|
||||
if strings.Contains(name, "ffn_down_exps") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
|
||||
@@ -127,7 +127,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
out := make([]*ggml.Tensor, 0, len(ts)+1)
|
||||
|
||||
if p.RopeScaling.factors != nil {
|
||||
out = append(out, &ggml.Tensor{
|
||||
@@ -176,9 +176,9 @@ func (p *llamaModel) Replacements() []string {
|
||||
}
|
||||
|
||||
func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
var dims []int
|
||||
for _, dim := range shape {
|
||||
dims = append(dims, int(dim))
|
||||
dims := make([]int, len(shape))
|
||||
for i, dim := range shape {
|
||||
dims[i] = int(dim)
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
|
||||
@@ -30,8 +30,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out := make([]*ggml.Tensor, len(ts))
|
||||
for i, t := range ts {
|
||||
shape := t.Shape()
|
||||
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
|
||||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
|
||||
@@ -41,12 +41,12 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out[i] = &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -90,9 +90,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
out := make([]*ggml.Tensor, len(ts))
|
||||
for i, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") {
|
||||
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
|
||||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
|
||||
@@ -100,12 +99,12 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
out[i] = &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -145,9 +144,9 @@ func (p *mistral3Model) Replacements() []string {
|
||||
}
|
||||
|
||||
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
var dims []int
|
||||
for _, dim := range shape {
|
||||
dims = append(dims, int(dim))
|
||||
dims := make([]int, len(shape))
|
||||
for i, dim := range shape {
|
||||
dims[i] = int(dim)
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
|
||||
@@ -49,20 +49,20 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
out := make([]*ggml.Tensor, len(ts))
|
||||
for i, t := range ts {
|
||||
out[i] = &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *qwen2Model) Replacements() []string {
|
||||
func (q *qwen2Model) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
|
||||
@@ -90,9 +90,9 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *qwen25VLModel) Replacements() []string {
|
||||
func (q *qwen25VLModel) Replacements() []string {
|
||||
return append(
|
||||
p.qwen2Model.Replacements(),
|
||||
q.qwen2Model.Replacements(),
|
||||
"visual", "v",
|
||||
"blocks", "blk",
|
||||
"attn.proj", "attn_out",
|
||||
|
||||
@@ -54,6 +54,6 @@ func (t torch) Clone() Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (pt torch) WriteTo(w io.Writer) (int64, error) {
|
||||
func (t torch) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||
content string
|
||||
}
|
||||
|
||||
var ts []t
|
||||
ts := make([]t, 0, len(atm))
|
||||
for content, id := range atm {
|
||||
ts = append(ts, t{id, content})
|
||||
}
|
||||
|
||||
@@ -300,9 +300,9 @@ func (s Tensors) Items(prefix ...string) []*Tensor {
|
||||
return items
|
||||
}
|
||||
|
||||
func (ts Tensors) GroupLayers() map[string]Layer {
|
||||
func (s Tensors) GroupLayers() map[string]Layer {
|
||||
layers := make(map[string]Layer)
|
||||
for _, t := range ts.items {
|
||||
for _, t := range s.items {
|
||||
parts := strings.Split(t.Name, ".")
|
||||
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
|
||||
if len(parts) > index+2 {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -225,7 +226,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
Shape: shape,
|
||||
}
|
||||
|
||||
llm.tensors = append(llm.tensors, &tensor)
|
||||
@@ -511,7 +512,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
return errors.New("architecture not set")
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||
|
||||
@@ -136,8 +136,8 @@ func (t FileType) Value() uint32 {
|
||||
return uint32(t)
|
||||
}
|
||||
|
||||
func (ftype FileType) ToTensorType() TensorType {
|
||||
switch ftype {
|
||||
func (t FileType) ToTensorType() TensorType {
|
||||
switch t {
|
||||
case FileTypeF32:
|
||||
return TensorTypeF32
|
||||
case FileTypeF16:
|
||||
@@ -177,7 +177,7 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
slog.Warn("unsupported file type", "type", t)
|
||||
return 0 // F32
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ type KeyValue struct {
|
||||
}
|
||||
|
||||
func (kv KeyValue) Valid() bool {
|
||||
return kv.Key != "" && kv.Value.value != nil
|
||||
return kv.Key != "" && kv.value != nil
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
|
||||
@@ -200,9 +200,7 @@ func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
|
||||
before := raw[:channelIndex]
|
||||
after := raw[channelIndex+len("<|channel|>"):]
|
||||
// the channel name is `after` all the way up to the first (if any) whitespace character
|
||||
idx := strings.IndexFunc(after, func(r rune) bool {
|
||||
return unicode.IsSpace(r)
|
||||
})
|
||||
idx := strings.IndexFunc(after, unicode.IsSpace)
|
||||
if idx == -1 {
|
||||
idx = len(after)
|
||||
}
|
||||
@@ -319,11 +317,12 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
||||
}
|
||||
case HarmonyEventContentEmitted:
|
||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||
if h.state == harmonyMessageState_Normal {
|
||||
switch h.state {
|
||||
case harmonyMessageState_Normal:
|
||||
contentSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_Thinking {
|
||||
case harmonyMessageState_Thinking:
|
||||
thinkingSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_ToolCalling {
|
||||
case harmonyMessageState_ToolCalling:
|
||||
toolContentSb.WriteString(event.Content)
|
||||
}
|
||||
case HarmonyEventMessageEnd:
|
||||
|
||||
@@ -263,9 +263,9 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
|
||||
cparams.use_mmap = C.bool(params.UseMmap)
|
||||
cparams.vocab_only = C.bool(params.VocabOnly)
|
||||
|
||||
var devices []C.ggml_backend_dev_t
|
||||
for _, llamaID := range params.Devices {
|
||||
devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
|
||||
devices := make([]C.ggml_backend_dev_t, len(params.Devices))
|
||||
for i, llamaID := range params.Devices {
|
||||
devices[i] = C.ggml_backend_dev_get(C.size_t(llamaID))
|
||||
}
|
||||
if len(devices) > 0 {
|
||||
devices = append(devices, C.ggml_backend_dev_t(C.NULL))
|
||||
|
||||
@@ -250,7 +250,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
err := fmt.Errorf("error starting runner: %v %s", err, msg)
|
||||
err := fmt.Errorf("error starting runner: %w %s", err, msg)
|
||||
if llamaModel != nil {
|
||||
llama.FreeModel(llamaModel)
|
||||
}
|
||||
@@ -846,14 +846,7 @@ nextOperation:
|
||||
func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
|
||||
devices := []ml.DeviceID{}
|
||||
for _, layer := range gpuLayers {
|
||||
new := true
|
||||
for _, ID := range devices {
|
||||
if layer.DeviceID == ID {
|
||||
new = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if new {
|
||||
if !slices.Contains(devices, layer.DeviceID) {
|
||||
devices = append(devices, layer.DeviceID)
|
||||
}
|
||||
}
|
||||
@@ -989,13 +982,11 @@ nextLayer:
|
||||
slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
|
||||
return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available))
|
||||
}
|
||||
} else {
|
||||
if vramSize > systemInfo.TotalMemory {
|
||||
// disable partial offloading when model is greater than total system memory as this
|
||||
// can lead to locking up the system
|
||||
s.options.NumGPU = 0
|
||||
gpuLayers = ml.GPULayersList{}
|
||||
}
|
||||
} else if vramSize > systemInfo.TotalMemory {
|
||||
// disable partial offloading when model is greater than total system memory as this
|
||||
// can lead to locking up the system
|
||||
s.options.NumGPU = 0
|
||||
gpuLayers = ml.GPULayersList{}
|
||||
}
|
||||
|
||||
if gpuLayers.Sum() == 0 {
|
||||
@@ -1218,7 +1209,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
|
||||
if err != nil {
|
||||
return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
|
||||
return ServerStatusError, fmt.Errorf("error creating GET request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
@@ -1481,7 +1472,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
// User provided a JSON schema
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
return errors.New("invalid JSON schema in format")
|
||||
}
|
||||
req.Grammar = string(g)
|
||||
}
|
||||
@@ -1521,13 +1512,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(req); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
return fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
return fmt.Errorf("error creating POST request: %w", err)
|
||||
}
|
||||
serverReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
@@ -1576,7 +1567,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
|
||||
var c CompletionResponse
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||
return fmt.Errorf("error unmarshalling llm prediction response: %w", err)
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
@@ -1618,7 +1609,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("an error was encountered while running the model: %s", msg)
|
||||
}
|
||||
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
return fmt.Errorf("error reading llm response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1693,7 +1684,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
|
||||
defer s.llamaModelLock.Unlock()
|
||||
|
||||
if s.llamaModel == nil {
|
||||
return nil, fmt.Errorf("no tokenizer configured")
|
||||
return nil, errors.New("no tokenizer configured")
|
||||
}
|
||||
|
||||
return s.llamaModel.Tokenize(content, false, true)
|
||||
@@ -1718,15 +1709,15 @@ func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
|
||||
defer s.llamaModelLock.Unlock()
|
||||
|
||||
if s.llamaModel == nil {
|
||||
return "", fmt.Errorf("no tokenizer configured")
|
||||
return "", errors.New("no tokenizer configured")
|
||||
}
|
||||
|
||||
var resp string
|
||||
var sb strings.Builder
|
||||
for _, token := range tokens {
|
||||
resp += s.llamaModel.TokenToPiece(token)
|
||||
sb.WriteString(s.llamaModel.TokenToPiece(token))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
|
||||
@@ -209,7 +209,7 @@ func TestLLMServerFitGPU(t *testing.T) {
|
||||
}
|
||||
|
||||
gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
|
||||
if err != tt.expectedErr {
|
||||
if !errors.Is(err, tt.expectedErr) {
|
||||
t.Fatalf("fitGPU returned error: %v", err)
|
||||
}
|
||||
if gpuLayers.Hash() != tt.expected.Hash() {
|
||||
|
||||
@@ -84,7 +84,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
@@ -150,7 +150,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -164,7 +164,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -189,7 +189,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
@@ -214,7 +214,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
@@ -240,7 +240,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
@@ -265,7 +265,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func TestEmbeddingsMiddleware_EncodingFormats(t *testing.T) {
|
||||
|
||||
switch tc.expectType {
|
||||
case "array":
|
||||
if _, ok := result.Data[0].Embedding.([]interface{}); !ok {
|
||||
if _, ok := result.Data[0].Embedding.([]any); !ok {
|
||||
t.Errorf("expected array, got %T", result.Data[0].Embedding)
|
||||
}
|
||||
case "string":
|
||||
@@ -210,10 +210,8 @@ func TestEmbeddingsMiddleware_InvalidEncodingFormat(t *testing.T) {
|
||||
if !strings.Contains(errResp.Error.Message, "encoding_format") {
|
||||
t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message)
|
||||
}
|
||||
} else {
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
} else if resp.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -845,19 +845,17 @@ func TestListMiddleware(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var expected, actual map[string]any
|
||||
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||
if err != nil {
|
||||
var want, got map[string]any
|
||||
if err := json.Unmarshal([]byte(tc.resp), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("response does not match (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
@@ -92,7 +93,7 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
|
||||
return backend(modelPath, params)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported backend")
|
||||
return nil, errors.New("unsupported backend")
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
|
||||
@@ -178,14 +178,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
requiredMemory.CPU.Cache = make([]uint64, blocks+1)
|
||||
|
||||
// create list of buffer types for each gpu
|
||||
var gpuDeviceBufferTypes []deviceBufferType
|
||||
gpuDeviceBufferTypes := make([]deviceBufferType, len(gpus))
|
||||
requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
|
||||
for i, d := range gpus {
|
||||
bt := C.ggml_backend_dev_buffer_type(d)
|
||||
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
|
||||
gpuDeviceBufferTypes[i] = deviceBufferType{
|
||||
d: d,
|
||||
bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...),
|
||||
})
|
||||
}
|
||||
|
||||
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
|
||||
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
|
||||
@@ -354,8 +354,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
deviceBufferTypes := make(map[C.ggml_backend_dev_t]C.ggml_backend_buffer_type_t)
|
||||
|
||||
// create backends and buffer types used for the compute graph scheduler
|
||||
var schedBackends []C.ggml_backend_t
|
||||
var schedBufts []C.ggml_backend_buffer_type_t
|
||||
schedBackends := make([]C.ggml_backend_t, 0, len(cpus)+len(accels)+len(gpus))
|
||||
schedBufts := make([]C.ggml_backend_buffer_type_t, 0, len(cpus)+len(accels)+len(gpus))
|
||||
for _, d := range append(gpus, append(accels, cpus...)...) {
|
||||
b := backends[d]
|
||||
bt := C.ggml_backend_get_default_buffer_type(b)
|
||||
|
||||
38
ml/device.go
38
ml/device.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
@@ -218,7 +219,7 @@ type BackendMemory struct {
|
||||
}
|
||||
|
||||
func (m BackendMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
attrs := make([]slog.Attr, 0, 2+len(m.GPUs))
|
||||
if m.InputWeights != 0 {
|
||||
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
|
||||
}
|
||||
@@ -414,14 +415,7 @@ func LibraryPaths(l []DeviceInfo) []string {
|
||||
gpuLibs := []string{LibOllamaPath}
|
||||
for _, gpu := range l {
|
||||
for _, dir := range gpu.LibraryPath {
|
||||
needed := true
|
||||
for _, existing := range gpuLibs {
|
||||
if dir == existing {
|
||||
needed = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needed {
|
||||
if !slices.Contains(gpuLibs, dir) {
|
||||
gpuLibs = append(gpuLibs, dir)
|
||||
}
|
||||
}
|
||||
@@ -437,15 +431,15 @@ const (
|
||||
DuplicateDevice // The same physical device but different library/backend (overlapping device)
|
||||
)
|
||||
|
||||
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
|
||||
if a.PCIID != b.PCIID {
|
||||
func (d DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
|
||||
if d.PCIID != b.PCIID {
|
||||
return UniqueDevice
|
||||
}
|
||||
// If PCIID is empty, we have to use ID + library for uniqueness
|
||||
if a.PCIID == "" && a.DeviceID != b.DeviceID {
|
||||
if d.PCIID == "" && d.DeviceID != b.DeviceID {
|
||||
return UniqueDevice
|
||||
}
|
||||
if a.Library == b.Library {
|
||||
if d.Library == b.Library {
|
||||
return SameBackendDevice
|
||||
}
|
||||
return DuplicateDevice
|
||||
@@ -453,8 +447,8 @@ func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
|
||||
|
||||
// For a SameBackendDevice, return true if b is better than a
|
||||
// e.g. newer GPU library version
|
||||
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
|
||||
aLib := a.LibraryPath[len(a.LibraryPath)-1]
|
||||
func (d DeviceInfo) IsBetter(b DeviceInfo) bool {
|
||||
aLib := d.LibraryPath[len(d.LibraryPath)-1]
|
||||
bLib := b.LibraryPath[len(b.LibraryPath)-1]
|
||||
if aLib == bLib {
|
||||
return false
|
||||
@@ -481,7 +475,7 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "cpu" ||
|
||||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && (gpu.ComputeMajor != 7 || gpu.ComputeMinor != 2)) ||
|
||||
gpu.Library == "ROCm" ||
|
||||
gpu.Library == "Vulkan"
|
||||
|
||||
@@ -549,12 +543,12 @@ func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string) {
|
||||
}
|
||||
v, existing := env[envVar]
|
||||
if existing {
|
||||
v = v + ","
|
||||
v += ","
|
||||
}
|
||||
if d.FilterID != "" {
|
||||
v = v + d.FilterID
|
||||
v += d.FilterID
|
||||
} else {
|
||||
v = v + d.ID
|
||||
v += d.ID
|
||||
}
|
||||
env[envVar] = v
|
||||
}
|
||||
@@ -594,7 +588,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to finish discovery before timeout")
|
||||
return nil, errors.New("failed to finish discovery before timeout")
|
||||
case <-tick:
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
|
||||
if err != nil {
|
||||
@@ -606,7 +600,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
|
||||
if err != nil {
|
||||
// slog.Warn("failed to send request", "error", err)
|
||||
if runner.HasExited() {
|
||||
return nil, fmt.Errorf("runner crashed")
|
||||
return nil, errors.New("runner crashed")
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -614,7 +608,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
// old runner, fall back to bootstrapping model
|
||||
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
|
||||
return nil, errors.New("llamarunner free vram reporting not supported")
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -143,9 +143,9 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
r += 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
r += 0x00a2
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
@@ -264,9 +264,9 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
r -= 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
r -= 0x00a2
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
|
||||
@@ -146,7 +146,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
func modelForArch(c fs.Config) (Model, error) {
|
||||
arch := c.Architecture()
|
||||
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||
arch = arch + "_embed"
|
||||
arch += "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
@@ -175,9 +175,10 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
tagsCopy = append(tagsCopy, parseTag(tag))
|
||||
}
|
||||
|
||||
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||
switch {
|
||||
case tt == reflect.TypeFor[Base]():
|
||||
vv.Set(reflect.ValueOf(base))
|
||||
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||
case tt == reflect.TypeFor[ml.Tensor]():
|
||||
var fn func([]Tag, string, string) [][]string
|
||||
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
|
||||
if len(tags) > 0 {
|
||||
@@ -217,9 +218,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
|
||||
case tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface:
|
||||
setPointer(base, vv, tagsCopy)
|
||||
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
|
||||
case tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array:
|
||||
for i := range vv.Len() {
|
||||
vvv := vv.Index(i)
|
||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||
|
||||
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.attnKeyLen, m.ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -178,10 +178,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
if len(m.Layers) == gemma27BLayerCount {
|
||||
m.Options.largeModelScaling = true
|
||||
m.largeModelScaling = true
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
@@ -202,9 +202,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenState = m.Output.Forward(ctx, hiddenState)
|
||||
|
||||
// final logit softcap
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
|
||||
return hiddenState.Scale(ctx, float64(m.finalLogitSoftcap)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -96,15 +96,15 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
m.imageSize,
|
||||
m.imageSize,
|
||||
m.numChannels,
|
||||
)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
|
||||
@@ -111,12 +111,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase := m.TextConfig.ropeLocalBase
|
||||
ropeBase := m.ropeLocalBase
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = m.TextConfig.ropeGlobalBase
|
||||
ropeBase = m.ropeGlobalBase
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.attnKeyLen, ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
@@ -166,7 +166,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
var except []int
|
||||
|
||||
@@ -53,7 +53,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
MultiModalProjector: newMultiModalProjector(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -109,12 +109,12 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, size, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, size, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
|
||||
pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.numChannels)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
|
||||
|
||||
@@ -133,7 +133,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
|
||||
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
|
||||
// Prepare position IDs for 2D rope
|
||||
positions := make([]int32, numPatches)
|
||||
|
||||
@@ -54,7 +54,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
encoderCache := kvcache.NewEncoderCache()
|
||||
encoderCache.SetConfig(ml.CacheConfig{})
|
||||
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
||||
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.Shift))
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, ratio, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, ratio, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -223,8 +223,8 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, cros
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
var decoderLayers []TextDecoderLayer
|
||||
for i := range c.Uint("block_count") {
|
||||
decoderLayers := make([]TextDecoderLayer, c.Uint("block_count"))
|
||||
for i := range decoderLayers {
|
||||
var textDecoderLayer TextDecoderLayer
|
||||
if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
|
||||
textDecoderLayer = &TextCrossAttentionDecoderLayer{}
|
||||
@@ -232,7 +232,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
textDecoderLayer = &TextSelfAttentionDecoderLayer{}
|
||||
}
|
||||
|
||||
decoderLayers = append(decoderLayers, textDecoderLayer)
|
||||
decoderLayers[i] = textDecoderLayer
|
||||
}
|
||||
|
||||
return &TextModel{
|
||||
|
||||
@@ -2,6 +2,7 @@ package qwen2
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
@@ -130,7 +131,7 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
// This model currently only supports the gpt2 tokenizer
|
||||
if c.String("tokenizer.ggml.model") == "llama" {
|
||||
return nil, fmt.Errorf("unsupported tokenizer: llama")
|
||||
return nil, errors.New("unsupported tokenizer: llama")
|
||||
}
|
||||
// detect library/qwen model(s) which are incompatible
|
||||
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {
|
||||
|
||||
@@ -48,7 +48,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -59,14 +59,13 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, grid, err := m.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||
patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
|
||||
|
||||
@@ -228,7 +228,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||
|
||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
|
||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.numHeads)
|
||||
// Apply encoder layers
|
||||
for i, layer := range m.Layers {
|
||||
if slices.Contains(m.fullAttnBlocks, int32(i)) {
|
||||
|
||||
@@ -107,7 +107,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
|
||||
return nil, nil, fmt.Errorf("failed to create patches: %w", err)
|
||||
}
|
||||
|
||||
// Return patches and grid dimensions
|
||||
|
||||
@@ -203,7 +203,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
|
||||
@@ -111,7 +111,7 @@ func (p *ImageProcessor) ProcessImage(ctx ml.Context, img image.Image) (ml.Tenso
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
|
||||
return nil, nil, fmt.Errorf("failed to create patches: %w", err)
|
||||
}
|
||||
|
||||
patchDim := p.numChannels * p.temporalPatchSize *
|
||||
|
||||
@@ -98,7 +98,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
|
||||
if multiStepTool && message.Role == "user" {
|
||||
// Check if content starts with <tool_response> and ends with </tool_response>
|
||||
content := r.renderContent(message)
|
||||
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
|
||||
if !strings.HasPrefix(content, "<tool_response>") || !strings.HasSuffix(content, "</tool_response>") {
|
||||
multiStepTool = false
|
||||
lastQueryIndex = i
|
||||
}
|
||||
|
||||
@@ -205,12 +205,12 @@ func (q queue) Less(i, j int) bool {
|
||||
|
||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||
|
||||
func (q *queue) Push(x interface{}) {
|
||||
func (q *queue) Push(x any) {
|
||||
item := x.(*candidate)
|
||||
*q = append(*q, item)
|
||||
}
|
||||
|
||||
func (q *queue) Pop() interface{} {
|
||||
func (q *queue) Pop() any {
|
||||
old := *q
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
@@ -231,7 +231,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse hex byte: %v", err)
|
||||
return "", fmt.Errorf("failed to parse hex byte: %w", err)
|
||||
}
|
||||
|
||||
if err := sb.WriteByte(byte(byteVal)); err != nil {
|
||||
|
||||
@@ -232,9 +232,9 @@ func NewError(code int, message string) ErrorResponse {
|
||||
// ToUsage converts an api.ChatResponse to Usage
|
||||
func ToUsage(r api.ChatResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,9 +326,9 @@ func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
|
||||
// ToUsageGenerate converts an api.GenerateResponse to Usage
|
||||
func ToUsageGenerate(r api.GenerateResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,20 +377,19 @@ func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
|
||||
// ToListCompletion converts an api.ListResponse to ListCompletion
|
||||
func ToListCompletion(r api.ListResponse) ListCompletion {
|
||||
var data []Model
|
||||
for _, m := range r.Models {
|
||||
data = append(data, Model{
|
||||
Id: m.Name,
|
||||
Object: "model",
|
||||
Created: m.ModifiedAt.Unix(),
|
||||
OwnedBy: model.ParseName(m.Name).Namespace,
|
||||
})
|
||||
}
|
||||
|
||||
return ListCompletion{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
c := ListCompletion{Object: "list"}
|
||||
if len(r.Models) > 0 {
|
||||
c.Data = make([]Model, len(r.Models))
|
||||
for i, m := range r.Models {
|
||||
c.Data[i] = Model{
|
||||
Id: m.Name,
|
||||
Object: "model",
|
||||
Created: m.ModifiedAt.Unix(),
|
||||
OwnedBy: model.ParseName(m.Name).Namespace,
|
||||
}
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
|
||||
@@ -487,19 +486,14 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||
valid := false
|
||||
// support blank mime type to match api/chat taking just unadorned base64
|
||||
if strings.HasPrefix(url, "data:;base64,") {
|
||||
url = strings.TrimPrefix(url, "data:;base64,")
|
||||
valid = true
|
||||
}
|
||||
for _, t := range types {
|
||||
prefix := "data:image/" + t + ";base64,"
|
||||
if strings.HasPrefix(url, prefix) {
|
||||
url = strings.TrimPrefix(url, prefix)
|
||||
valid = true
|
||||
break
|
||||
url, valid := strings.CutPrefix(url, "data:;base64,")
|
||||
if !valid {
|
||||
for _, t := range []string{"jpeg", "jpg", "png", "webp"} {
|
||||
prefix := "data:image/" + t + ";base64,"
|
||||
url, valid = strings.CutPrefix(url, prefix)
|
||||
if valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
@@ -78,9 +79,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
if req.Files == nil {
|
||||
req.Files = digestMap
|
||||
} else {
|
||||
for k, v := range digestMap {
|
||||
req.Files[k] = v
|
||||
}
|
||||
maps.Copy(req.Files, digestMap)
|
||||
}
|
||||
case "adapter":
|
||||
path, err := expandPath(c.Args, relativeDir)
|
||||
@@ -371,7 +370,7 @@ func (e *ParserError) Error() string {
|
||||
func ParseFile(r io.Reader) (*Modelfile, error) {
|
||||
var cmd Command
|
||||
var curr state
|
||||
var currLine int = 1
|
||||
currLine := 1
|
||||
var b bytes.Buffer
|
||||
var role string
|
||||
|
||||
|
||||
@@ -326,17 +326,11 @@ MESSAGE system`,
|
||||
return
|
||||
}
|
||||
|
||||
switch tt.err.(type) {
|
||||
case *ParserError:
|
||||
var pErr *ParserError
|
||||
if errors.As(err, &pErr) {
|
||||
// got the correct type of error
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, tt.err) {
|
||||
return
|
||||
} else if pErr := (*ParserError)(nil); errors.As(err, &pErr) {
|
||||
// got the correct type of error
|
||||
return
|
||||
}
|
||||
|
||||
t.Fatalf("unexpected error: expected: %v, actual: %v", tt.err, err)
|
||||
@@ -1089,7 +1083,7 @@ func TestFilesForModel(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Error("Expected error, but got none")
|
||||
}
|
||||
if tt.expectErrType != nil && err != tt.expectErrType {
|
||||
if tt.expectErrType != nil && !errors.Is(err, tt.expectErrType) {
|
||||
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
|
||||
}
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ package readline
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/emirpasic/gods/v2/lists/arraylist"
|
||||
"github.com/mattn/go-runewidth"
|
||||
@@ -297,7 +298,7 @@ func (b *Buffer) drawRemaining() {
|
||||
remaining := (remainingText[len(currLine):])
|
||||
var totalLines int
|
||||
var displayLength int
|
||||
var lineLength int = currLineSpace
|
||||
lineLength := currLineSpace
|
||||
|
||||
for _, c := range remaining {
|
||||
if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth {
|
||||
@@ -515,13 +516,13 @@ func (b *Buffer) StringN(n int) string {
|
||||
}
|
||||
|
||||
func (b *Buffer) StringNM(n, m int) string {
|
||||
var s string
|
||||
var sb strings.Builder
|
||||
if m == 0 {
|
||||
m = b.Buf.Size()
|
||||
}
|
||||
for cnt := n; cnt < m; cnt++ {
|
||||
c, _ := b.Buf.Get(cnt)
|
||||
s += string(c)
|
||||
sb.WriteRune(c)
|
||||
}
|
||||
return s
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -124,18 +125,19 @@ func (i *Instance) Readline() (string, error) {
|
||||
case KeyRight:
|
||||
buf.MoveRight()
|
||||
case CharBracketedPaste:
|
||||
var code string
|
||||
var code strings.Builder
|
||||
for range 3 {
|
||||
r, err = i.Terminal.Read()
|
||||
if err != nil {
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
code += string(r)
|
||||
code.WriteRune(r)
|
||||
}
|
||||
if code == CharBracketedPasteStart {
|
||||
switch code.String() {
|
||||
case CharBracketedPasteStart:
|
||||
i.Pasting = true
|
||||
} else if code == CharBracketedPasteEnd {
|
||||
case CharBracketedPasteEnd:
|
||||
i.Pasting = false
|
||||
}
|
||||
case KeyDel:
|
||||
|
||||
@@ -459,10 +459,7 @@ func TestLogprobsWithStopSequences(t *testing.T) {
|
||||
|
||||
origLogprobsLen := len(logprobs)
|
||||
numTokensRemoved := origLen - newLen
|
||||
newLogprobsLen := origLogprobsLen - numTokensRemoved
|
||||
if newLogprobsLen < 0 {
|
||||
newLogprobsLen = 0
|
||||
}
|
||||
newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
|
||||
logprobs = logprobs[:newLogprobsLen]
|
||||
|
||||
// Verify responses were truncated correctly
|
||||
|
||||
@@ -39,21 +39,15 @@ func TruncateStop(pieces []string, stop string) ([]string, bool) {
|
||||
|
||||
joined = joined[:index]
|
||||
|
||||
// Split truncated string back into pieces of original lengths
|
||||
lengths := make([]int, len(pieces))
|
||||
for i, piece := range pieces {
|
||||
lengths[i] = len(piece)
|
||||
}
|
||||
|
||||
var result []string
|
||||
result := make([]string, 0, len(pieces))
|
||||
tokenTruncated := false
|
||||
start := 0
|
||||
for _, length := range lengths {
|
||||
for _, piece := range pieces {
|
||||
if start >= len(joined) {
|
||||
break
|
||||
}
|
||||
|
||||
end := start + length
|
||||
end := start + len(piece)
|
||||
if end > len(joined) {
|
||||
end = len(joined)
|
||||
tokenTruncated = true
|
||||
|
||||
@@ -61,7 +61,7 @@ func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []by
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(data) <= 0 {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("received zero length image")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -18,7 +19,7 @@ func TestImageCache(t *testing.T) {
|
||||
|
||||
// Empty cache
|
||||
result, err := cache.findImage(0x5adb61d31933a946)
|
||||
if err != errImageNotFound {
|
||||
if !errors.Is(err, errImageNotFound) {
|
||||
t.Errorf("found result in empty cache: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -577,10 +577,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
if seq.logprobs {
|
||||
origLogprobsLen := len(seq.pendingLogprobs)
|
||||
numTokensRemoved := origLen - newLen
|
||||
newLogprobsLen := origLogprobsLen - numTokensRemoved
|
||||
if newLogprobsLen < 0 {
|
||||
newLogprobsLen = 0
|
||||
}
|
||||
newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
|
||||
seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
|
||||
}
|
||||
|
||||
@@ -998,7 +995,6 @@ func Execute(args []string) error {
|
||||
|
||||
log.Println("Server listening on", addr)
|
||||
if err := httpServer.Serve(listener); err != nil {
|
||||
log.Fatal("server error:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package ollamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -511,7 +510,7 @@ type mockCache struct {
|
||||
// Implement only the methods needed for the test
|
||||
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if m.shouldFail {
|
||||
return fmt.Errorf("mock cache removal error")
|
||||
return errors.New("mock cache removal error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -801,10 +801,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
if seq.logprobs {
|
||||
origLogprobsLen := len(seq.pendingLogprobs)
|
||||
numTokensRemoved := origLen - newLen
|
||||
newLogprobsLen := origLogprobsLen - numTokensRemoved
|
||||
if newLogprobsLen < 0 {
|
||||
newLogprobsLen = 0
|
||||
}
|
||||
newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
|
||||
seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
|
||||
}
|
||||
|
||||
@@ -1242,7 +1239,7 @@ func (s *Server) loadModel() {
|
||||
s.progress = progress
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to load model: %v", err))
|
||||
panic(fmt.Errorf("failed to load model: %w", err))
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusReady
|
||||
@@ -1432,7 +1429,6 @@ func Execute(args []string) error {
|
||||
|
||||
log.Println("Server listening on", addr)
|
||||
if err := httpServer.Serve(listener); err != nil {
|
||||
log.Fatal("server error:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
temp = max(temp, 1e-7)
|
||||
for i := range ts {
|
||||
ts[i].value = ts[i].value / temp
|
||||
ts[i].value /= temp
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ func (r registryChallenge) URL() (*url.URL, error) {
|
||||
|
||||
values := redirectURL.Query()
|
||||
values.Add("service", r.Service)
|
||||
for _, s := range strings.Split(r.Scope, " ") {
|
||||
for s := range strings.SplitSeq(r.Scope, " ") {
|
||||
values.Add("scope", s)
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
|
||||
}
|
||||
|
||||
sha256sum := sha256.Sum256(nil)
|
||||
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||
data := fmt.Appendf(nil, "%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))
|
||||
|
||||
headers := make(http.Header)
|
||||
signature, err := auth.Sign(ctx, data)
|
||||
@@ -75,7 +75,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%d: %v", response.StatusCode, err)
|
||||
return "", fmt.Errorf("%d: %w", response.StatusCode, err)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
|
||||
@@ -386,7 +386,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
}
|
||||
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
// Path is likely outside the root
|
||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||
return nil, fmt.Errorf("%w: %w: %s", errFilePath, err, fp)
|
||||
}
|
||||
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
@@ -456,15 +456,15 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||
return l.KV(), nil
|
||||
}
|
||||
}
|
||||
return ggml.KV{}, fmt.Errorf("no base model was found")
|
||||
return ggml.KV{}, errors.New("no base model was found")
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||
var layers []Layer
|
||||
for _, layer := range baseLayers {
|
||||
layers := make([]Layer, len(baseLayers))
|
||||
for i, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
|
||||
if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
|
||||
if quantType != "" && layer.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
|
||||
want, err := ggml.ParseFileType(quantType)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -480,13 +480,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
}
|
||||
}
|
||||
}
|
||||
config.ModelFormat = cmp.Or(config.ModelFormat, layer.GGML.Name())
|
||||
config.ModelFormat = cmp.Or(config.ModelFormat, layer.Name())
|
||||
config.ModelFamily = cmp.Or(config.ModelFamily, layer.GGML.KV().Architecture())
|
||||
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
|
||||
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
|
||||
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
|
||||
}
|
||||
layers = append(layers, layer.Layer)
|
||||
layers[i] = layer.Layer
|
||||
}
|
||||
|
||||
if r.Template != "" {
|
||||
@@ -678,10 +678,10 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
|
||||
func setTemplate(layers []Layer, t string) ([]Layer, error) {
|
||||
layers = removeLayer(layers, "application/vnd.ollama.image.template")
|
||||
if _, err := template.Parse(t); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||
return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
|
||||
}
|
||||
if _, err := template.Parse(t); err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||
return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
|
||||
}
|
||||
|
||||
blob := strings.NewReader(t)
|
||||
|
||||
@@ -640,7 +640,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
|
||||
manifest, err = pullModelManifest(ctx, mp, regOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pull model manifest: %s", err)
|
||||
return fmt.Errorf("pull model manifest: %w", err)
|
||||
}
|
||||
|
||||
var layers []Layer
|
||||
@@ -786,7 +786,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
return nil, fmt.Errorf("%d: %w", resp.StatusCode, err)
|
||||
}
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
default:
|
||||
|
||||
2
server/internal/cache/blob/cache.go
vendored
2
server/internal/cache/blob/cache.go
vendored
@@ -438,7 +438,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
// last write. check hash.
|
||||
sum := w.h.Sum(nil)
|
||||
if !bytes.Equal(sum, w.d.sum[:]) {
|
||||
return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
|
||||
return 0, w.seterr(errors.New("file content changed underfoot"))
|
||||
}
|
||||
if w.testHookBeforeFinalWrite != nil {
|
||||
w.testHookBeforeFinalWrite(w.f)
|
||||
|
||||
3
server/internal/cache/blob/casecheck_test.go
vendored
3
server/internal/cache/blob/casecheck_test.go
vendored
@@ -84,8 +84,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
|
||||
|
||||
// TODO(bmizerany): Print platform-specific instructions or
|
||||
// link to docs on that topic.
|
||||
lines := strings.Split(volumeHint, "\n")
|
||||
for _, line := range lines {
|
||||
for line := range strings.SplitSeq(volumeHint, "\n") {
|
||||
t.Skip(line)
|
||||
}
|
||||
}
|
||||
|
||||
2
server/internal/cache/blob/digest.go
vendored
2
server/internal/cache/blob/digest.go
vendored
@@ -60,7 +60,7 @@ func (d Digest) String() string {
|
||||
}
|
||||
|
||||
func (d Digest) Short() string {
|
||||
return fmt.Sprintf("%x", d.sum[:4])
|
||||
return hex.EncodeToString(d.sum[:4])
|
||||
}
|
||||
|
||||
func (d Digest) Sum() [32]byte {
|
||||
|
||||
@@ -1184,11 +1184,11 @@ func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
|
||||
}
|
||||
start, err := strconv.ParseInt(startPart, 10, 64)
|
||||
if err != nil {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %w", s, err)
|
||||
}
|
||||
end, err := strconv.ParseInt(endPart, 10, 64)
|
||||
if err != nil {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %w", s, err)
|
||||
}
|
||||
if start > end {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
|
||||
|
||||
@@ -142,7 +142,7 @@ var junkName Name
|
||||
|
||||
func BenchmarkParseName(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
junkName = Parse("h/n/m:t")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,15 +187,15 @@ func (w *relayWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *relayWriter) awaitTurn() (ok bool) {
|
||||
if t.ready {
|
||||
func (w *relayWriter) awaitTurn() (ok bool) {
|
||||
if w.ready {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-t.t.Ready():
|
||||
t.ready = true
|
||||
case <-w.t.Ready():
|
||||
w.ready = true
|
||||
return true
|
||||
case <-t.q.closed():
|
||||
case <-w.q.closed():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,7 +251,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
type progressUpdateJSON struct {
|
||||
Error string `json:"error,omitempty,omitzero"`
|
||||
Status string `json:"status,omitempty,omitzero"`
|
||||
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
||||
Digest blob.Digest `json:"digest,omitzero"`
|
||||
Total int64 `json:"total,omitempty,omitzero"`
|
||||
Completed int64 `json:"completed,omitempty,omitzero"`
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
return "", nil, errors.New("this model only supports one image while more than one image requested")
|
||||
}
|
||||
|
||||
var prefix string
|
||||
var prefix strings.Builder
|
||||
prompt := msg.Content
|
||||
|
||||
for _, i := range msg.Images {
|
||||
@@ -85,14 +85,14 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
|
||||
if !strings.Contains(prompt, "[img]") {
|
||||
prefix += imgTag
|
||||
prefix.WriteString(imgTag)
|
||||
} else {
|
||||
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
||||
}
|
||||
|
||||
images = append(images, imgData)
|
||||
}
|
||||
msgs[currMsgIdx+cnt].Content = prefix + prompt
|
||||
msgs[currMsgIdx+cnt].Content = prefix.String() + prompt
|
||||
}
|
||||
|
||||
// truncate any messages that do not fit into the context window
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -238,7 +239,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
} else if tt.error != nil && !errors.Is(err, tt.error) {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
||||
data, err := io.ReadAll(sr)
|
||||
if err != nil {
|
||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
|
||||
}
|
||||
var f32s []float32
|
||||
newType := fsggml.TensorType(q.to.Kind)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -129,7 +130,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
}
|
||||
|
||||
if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 {
|
||||
return nil, nil, nil, fmt.Errorf("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
|
||||
return nil, nil, nil, errors.New("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
|
||||
}
|
||||
|
||||
if err := model.CheckCapabilities(caps...); err != nil {
|
||||
@@ -361,11 +362,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
if req.Think == nil {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
} else if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -649,10 +648,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
truncate := true
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
}
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
|
||||
var input []string
|
||||
|
||||
@@ -825,9 +821,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var e []float64
|
||||
for _, v := range embedding {
|
||||
e = append(e, float64(v))
|
||||
e := make([]float64, len(embedding))
|
||||
for i, v := range embedding {
|
||||
e[i] = float64(v)
|
||||
}
|
||||
|
||||
resp := api.EmbeddingResponse{
|
||||
@@ -1139,9 +1135,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
if m.Options == nil {
|
||||
m.Options = make(map[string]any)
|
||||
}
|
||||
for k, v := range req.Options {
|
||||
m.Options[k] = v
|
||||
}
|
||||
maps.Copy(m.Options, req.Options)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
@@ -1212,7 +1206,7 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
models := []api.ListModelResponse{}
|
||||
models := make([]api.ListModelResponse, 0, len(ms))
|
||||
for n, m := range ms {
|
||||
var cf ConfigV2
|
||||
|
||||
@@ -1811,13 +1805,13 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
ExpiresAt: v.expiresAt,
|
||||
}
|
||||
if v.Options != nil {
|
||||
mr.ContextLength = v.Options.NumCtx
|
||||
mr.ContextLength = v.NumCtx
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
// calculate the time w/ the sessionDuration instead.
|
||||
var epoch time.Time
|
||||
if v.expiresAt == epoch {
|
||||
if v.expiresAt.Equal(epoch) {
|
||||
mr.ExpiresAt = time.Now().Add(v.sessionDuration)
|
||||
}
|
||||
|
||||
@@ -2000,11 +1994,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if req.Think == nil {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
} else if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
|
||||
@@ -196,11 +196,9 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
|
||||
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
|
||||
}
|
||||
} else {
|
||||
} else if w.Code != http.StatusOK {
|
||||
// When debug is disabled, it should attempt normal processing
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -401,11 +399,9 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
|
||||
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
|
||||
}
|
||||
} else {
|
||||
} else if w.Code != http.StatusOK {
|
||||
// When debug is disabled, it should attempt normal processing
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
mock.Content = "Hi!"
|
||||
|
||||
t.Run("chat-like flow uses renderer", func(t *testing.T) {
|
||||
// Test that when using messages (chat-like flow), the built-in renderer is used
|
||||
@@ -109,12 +109,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
|
||||
// The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
|
||||
// When messages are built internally from prompt, it should use the renderer
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
if !strings.Contains(mock.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") {
|
||||
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
if !strings.Contains(mock.Prompt, "<|im_end|>") {
|
||||
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -132,12 +132,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain the system message and use renderer format
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") {
|
||||
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt)
|
||||
if !strings.Contains(mock.Prompt, "<|im_start|>system") {
|
||||
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") {
|
||||
t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt)
|
||||
if !strings.Contains(mock.Prompt, "You are a helpful coding assistant.") {
|
||||
t.Errorf("expected prompt to contain system message content, got: %s", mock.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -155,12 +155,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when custom template is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
if strings.Contains(mock.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.Prompt)
|
||||
}
|
||||
|
||||
// Should just be the raw prompt from the template
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "Write a hello world function"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -191,12 +191,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when suffix is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
if strings.Contains(mock.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.Prompt)
|
||||
}
|
||||
|
||||
// Should use the suffix template format
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -41,7 +41,7 @@ func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn
|
||||
}
|
||||
|
||||
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
|
||||
for range strings.Fields(s) {
|
||||
for range strings.FieldsSeq(s) {
|
||||
tokens = append(tokens, len(tokens))
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
mock.Content = "Hi!"
|
||||
t.Run("messages", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
@@ -392,7 +392,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "user: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -422,14 +422,14 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkChatResponse(t, w.Body, "test-system", "Hi!")
|
||||
})
|
||||
|
||||
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||
mock.Content = "Abra kadabra!"
|
||||
t.Run("messages with system", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
@@ -444,7 +444,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -467,7 +467,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -985,7 +985,7 @@ func TestGenerate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
mock.Content = "Hi!"
|
||||
t.Run("prompt", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
@@ -997,7 +997,7 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -1025,14 +1025,14 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
|
||||
})
|
||||
|
||||
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||
mock.Content = "Abra kadabra!"
|
||||
t.Run("prompt with system", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-system",
|
||||
@@ -1045,7 +1045,7 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -1067,7 +1067,7 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -1097,7 +1097,7 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -1112,7 +1112,7 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "def add("); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -1129,7 +1129,7 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
|
||||
if diff := cmp.Diff(mock.Prompt, "Help me write tests."); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user