Mirror of https://github.com/ollama/ollama.git (synced 2026-01-03 21:19:20 -05:00)

Compare commits: `parth/add-` ... `grace/deep` (24 commits)

| SHA1 |
|---|
| e1878e6e33 |
| f0733c13b5 |
| 07162c509f |
| 5be8277683 |
| ec65cc3690 |
| e3731fb160 |
| 8dbc9e7b68 |
| abe67acf8a |
| 4ff8a691bc |
| 1b308e1d2a |
| bd6c1d6b49 |
| 3af5d3b738 |
| 7730895158 |
| de9ecfd01c |
| 95fdd8d619 |
| 9f7822851c |
| 9b2035d194 |
| 93d45d7a04 |
| 709f842457 |
| 2dfb74410d |
| 1eb5e75972 |
| 3475d915cb |
| 48e78e9be1 |
| a838421ea3 |
@@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
|
||||
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
|
||||
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
|
||||
- [Swollama for Swift]([https://github.com/marcusziade/Swollama](https://github.com/guitaripod/Swollama) with [DocC]( https://guitaripod.github.io/Swollama/documentation/swollama)
|
||||
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
|
||||
- [GoLamify](https://github.com/prasad89/golamify)
|
||||
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
|
||||
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
|
||||
|
||||
@@ -347,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error
|
||||
// Create creates a model from a [Modelfile]. fn is a progress function that
|
||||
// behaves similarly to other methods (see [Client.Pull]).
|
||||
//
|
||||
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
|
||||
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
|
||||
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
|
||||
var resp ProgressResponse
|
||||
|
||||
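The change in this hunk only repoints the [Modelfile] doc link at the new `.mdx` file. For context, a minimal caller of `Client.Create` with a progress callback could look like the sketch below; the client constructor is the standard `api.ClientFromEnvironment`, but the model name, the `From` field, and the request shape are illustrative assumptions rather than something taken from this diff.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Illustrative request: field names assumed from the api package,
	// deriving a new model from an existing base model.
	req := &api.CreateRequest{
		Model: "my-model",
		From:  "llama3.2",
	}

	// The progress function receives each ProgressResponse streamed back
	// from /api/create, similar to Client.Pull's callback.
	err = client.Create(context.Background(), req, func(p api.ProgressResponse) error {
		fmt.Println(p.Status)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```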
@@ -191,13 +191,6 @@ func LaunchNewApp() {
|
||||
C.launchApp(appName)
|
||||
}
|
||||
|
||||
// Send a request to the main app thread to load a UI page
|
||||
func sendUIRequestMessage(path string) {
|
||||
p := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(p))
|
||||
C.uiRequest(p)
|
||||
}
|
||||
|
||||
func registerLaunchAgent(hasCompletedFirstRun bool) {
|
||||
// Remove any stale Login Item registrations
|
||||
C.unregisterSelfFromLoginItem()
|
||||
|
||||
@@ -263,11 +263,6 @@ func createLoginShortcut() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send a request to the main app thread to load a UI page
|
||||
func sendUIRequestMessage(path string) {
|
||||
wintray.SendUIRequestMessage(path)
|
||||
}
|
||||
|
||||
func LaunchNewApp() {
|
||||
}
|
||||
|
||||
|
||||
@@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) {
|
||||
}
|
||||
|
||||
NSArray* urls = [panel URLs];
|
||||
if(self->params->allowMultiple && [urls count] >= 1) {
|
||||
if([urls count] == 0) {
|
||||
return DLG_CANCEL;
|
||||
}
|
||||
|
||||
if(self->params->allowMultiple) {
|
||||
// For multiple files, we need to return all paths separated by null bytes
|
||||
char* bufPtr = self->params->buf;
|
||||
int remainingBuf = self->params->nbuf;
|
||||
|
||||
// Calculate total required buffer size first
|
||||
int totalSize = 0;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
||||
}
|
||||
totalSize += 1; // Final null terminator
|
||||
// Calculate total required buffer size first
|
||||
int totalSize = 0;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
||||
}
|
||||
totalSize += 1; // Final null terminator
|
||||
|
||||
if(totalSize > self->params->nbuf) {
|
||||
// Not enough buffer space
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
if(totalSize > self->params->nbuf) {
|
||||
// Not enough buffer space
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
|
||||
// Now actually copy the paths (we know we have space)
|
||||
bufPtr = self->params->buf;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
||||
int pathLen = strlen(tempBuf);
|
||||
strcpy(bufPtr, tempBuf);
|
||||
bufPtr += pathLen + 1;
|
||||
}
|
||||
*bufPtr = '\0'; // Final null terminator
|
||||
// Now actually copy the paths (we know we have space)
|
||||
bufPtr = self->params->buf;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
||||
int pathLen = strlen(tempBuf);
|
||||
strcpy(bufPtr, tempBuf);
|
||||
bufPtr += pathLen + 1;
|
||||
}
|
||||
*bufPtr = '\0'; // Final null terminator
|
||||
} else {
|
||||
// Single file/directory selection - write path to buffer
|
||||
NSURL* url = [urls firstObject];
|
||||
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
}
|
||||
|
||||
return DLG_OK;
|
||||
|
||||
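The multi-select branch above packs every selected path into the caller's buffer as NUL-separated strings with one extra NUL terminator at the end. A small Go sketch of reading a buffer laid out that way follows; the function name and sample paths are made up, and the real consumer of this buffer may differ.

```go
package main

import (
	"bytes"
	"fmt"
)

// splitNULSeparated parses a buffer of the form "path1\x00path2\x00\x00",
// i.e. NUL-separated paths with an extra NUL marking the end, which is the
// layout the dialog code above writes into params->buf.
func splitNULSeparated(buf []byte) []string {
	var paths []string
	for {
		i := bytes.IndexByte(buf, 0)
		if i <= 0 { // empty segment (the final terminator) or no NUL left
			return paths
		}
		paths = append(paths, string(buf[:i]))
		buf = buf[i+1:]
	}
}

func main() {
	buf := []byte("/tmp/a.gguf\x00/tmp/b.gguf\x00\x00")
	fmt.Println(splitNULSeparated(buf)) // [/tmp/a.gguf /tmp/b.gguf]
}
```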
@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
|
||||
type WinDlgError int
|
||||
|
||||
func (e WinDlgError) Error() string {
|
||||
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
|
||||
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
|
||||
}
|
||||
|
||||
func err() error {
|
||||
|
||||
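The only functional change in this hunk is formatting `int(e)` instead of `e`. A plausible reason, not stated in the diff: fmt falls back to a value's Error method for verbs like `%#x`, so formatting `e` inside its own Error method would recurse forever, and go vet's printf check reports exactly this pattern. A self-contained illustration:

```go
package main

import "fmt"

type WinDlgError int

// Formatting the value itself here (fmt.Sprintf("%#x", e)) would make fmt
// call Error again and recurse without bound; formatting the underlying
// integer avoids that and keeps the hex output.
func (e WinDlgError) Error() string {
	return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
}

func main() {
	fmt.Println(error(WinDlgError(0x4005)))
}
```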
@@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
||||
if _, err := os.Stat(settings.Models); err == nil {
|
||||
env["OLLAMA_MODELS"] = settings.Models
|
||||
} else {
|
||||
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
|
||||
settings.Models = ""
|
||||
s.store.SetSettings(settings)
|
||||
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
|
||||
}
|
||||
}
|
||||
if settings.ContextLength > 0 {
|
||||
|
||||
82
app/ui/ui.go

@@ -12,13 +12,13 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -117,40 +117,66 @@ func (s *Server) log() *slog.Logger {
|
||||
|
||||
// ollamaProxy creates a reverse proxy handler to the Ollama server
|
||||
func (s *Server) ollamaProxy() http.Handler {
|
||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||
if ollamaHost == "" {
|
||||
ollamaHost = "http://127.0.0.1:11434"
|
||||
}
|
||||
var (
|
||||
proxy http.Handler
|
||||
proxyMu sync.Mutex
|
||||
)
|
||||
|
||||
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
|
||||
ollamaHost = "http://" + ollamaHost
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyMu.Lock()
|
||||
p := proxy
|
||||
proxyMu.Unlock()
|
||||
|
||||
target, err := url.Parse(ollamaHost)
|
||||
if err != nil {
|
||||
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
if p == nil {
|
||||
proxyMu.Lock()
|
||||
if proxy == nil {
|
||||
var err error
|
||||
for i := range 2 {
|
||||
if i > 0 {
|
||||
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
err = WaitForServer(context.Background(), 10*time.Second)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
if err != nil {
|
||||
proxyMu.Unlock()
|
||||
s.log().Error("ollama server not ready after retries", "error", err)
|
||||
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
originalDirector := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = target.Host
|
||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||
}
|
||||
target := envconfig.Host()
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
return proxy
|
||||
originalDirector := newProxy.Director
|
||||
newProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = target.Host
|
||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||
}
|
||||
|
||||
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
proxy = newProxy
|
||||
p = newProxy
|
||||
} else {
|
||||
p = proxy
|
||||
}
|
||||
proxyMu.Unlock()
|
||||
}
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
|
||||
|
||||
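The rewritten `ollamaProxy` above defers building the reverse proxy until the first request arrives, guarding the shared handler with a mutex and waiting for the Ollama server to become ready before constructing it. Stripped of the retry and logging details, the shape of that pattern is roughly the following sketch; the target URL and listen address are placeholders.

```go
package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"sync"
)

// lazyProxy mirrors the shape of the new ollamaProxy: the reverse proxy is
// built once, on the first request, guarded by a mutex, instead of at startup.
// (In the real handler this is also where the readiness wait and retries go.)
func lazyProxy(target *url.URL) http.Handler {
	var (
		mu    sync.Mutex
		proxy http.Handler
	)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		if proxy == nil {
			proxy = httputil.NewSingleHostReverseProxy(target)
		}
		p := proxy
		mu.Unlock()
		p.ServeHTTP(w, r)
	})
}

func main() {
	target, _ := url.Parse("http://127.0.0.1:11434")
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", lazyProxy(target)))
}
```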
@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
|
||||
case uint32(UI_REQUEST_MSG_ID):
|
||||
// Requests for the UI must always come from the main event thread
|
||||
l := int(wParam)
|
||||
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
|
||||
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
|
||||
t.app.UIRun(path)
|
||||
case WM_COPYDATA:
|
||||
// Handle URL scheme requests from other instances
|
||||
if lParam != 0 {
|
||||
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
|
||||
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
||||
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
|
||||
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
||||
// Convert the data back to string
|
||||
data := make([]byte, cds.CbData)
|
||||
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
|
||||
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
|
||||
urlScheme := string(data)
|
||||
handleURLSchemeRequest(urlScheme)
|
||||
lResult = 1 // Return non-zero to indicate success
|
||||
|
||||
@@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &llama4Model{}
|
||||
case "Mistral3ForConditionalGeneration":
|
||||
conv = &mistral3Model{}
|
||||
case "Ministral3ForCausalLM":
|
||||
conv = &mistral3CausalModel{}
|
||||
case "MixtralForCausalLM":
|
||||
conv = &mixtralModel{}
|
||||
case "GemmaForCausalLM":
|
||||
|
||||
@@ -30,13 +30,15 @@ type mistral3Model struct {
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RopeParameters struct {
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
ScalingBeta float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
Mscale *float32 `json:"mscale"`
|
||||
MscaleAllDim *float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
@@ -50,6 +52,9 @@ type mistral3Model struct {
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"vision_config"`
|
||||
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
@@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
|
||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
|
||||
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
|
||||
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
|
||||
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
|
||||
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
|
||||
|
||||
if p.TextModel.RopeParameters.Mscale != nil {
|
||||
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
|
||||
}
|
||||
if p.TextModel.RopeParameters.MscaleAllDim != nil {
|
||||
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
|
||||
}
|
||||
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
||||
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
|
||||
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
|
||||
}
|
||||
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
|
||||
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
// Vision configuration
|
||||
@@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
|
||||
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
|
||||
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
|
||||
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
|
||||
|
||||
// Multimodal configuration
|
||||
kv["mistral3.image_token_index"] = p.ImageTokenIndex
|
||||
|
||||
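Several of the kv assignments above use `cmp.Or` to prefer the legacy top-level value and fall back to the nested `rope_parameters` one. `cmp.Or` returns the first of its arguments that is not the zero value, so a missing `rope_theta` (0) falls through to `rope_parameters.rope_theta`:

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	var legacy float32 = 0         // top-level rope_theta absent from the config
	var nested float32 = 1_000_000 // rope_parameters.rope_theta present
	// cmp.Or returns the first non-zero argument.
	fmt.Println(cmp.Or(legacy, nested)) // 1e+06
}
```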
181
convert/convert_mistral_causal.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type mistral3CausalModel struct {
|
||||
ModelParameters
|
||||
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RopeParameters struct {
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
Mscale *float32 `json:"mscale"`
|
||||
MscaleAllDim *float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
// Text configuration
|
||||
kv["mistral3.block_count"] = p.NumHiddenLayers
|
||||
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["mistral3.embedding_length"] = p.HiddenSize
|
||||
kv["mistral3.feed_forward_length"] = p.IntermediateSize
|
||||
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
|
||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
|
||||
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
|
||||
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
|
||||
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
|
||||
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
|
||||
|
||||
if p.RopeParameters.Mscale != nil {
|
||||
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
|
||||
}
|
||||
|
||||
if p.RopeParameters.MscaleAllDim != nil {
|
||||
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
|
||||
}
|
||||
|
||||
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
||||
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
|
||||
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
if p.RopeParameters.Llama4ScalingBeta != nil {
|
||||
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") {
|
||||
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
|
||||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.norm", "output_norm",
|
||||
"model.", "",
|
||||
"layers", "blk",
|
||||
"transformer.layers", "blk",
|
||||
"vision_tower", "v",
|
||||
"ln_pre", "encoder_norm",
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"embed_tokens", "token_embd",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"attention.q_proj", "attn_q",
|
||||
"attention.k_proj", "attn_k",
|
||||
"attention.v_proj", "attn_v",
|
||||
"attention.o_proj", "attn_output",
|
||||
"attention_norm", "attn_norm",
|
||||
"feed_forward.gate_proj", "ffn_gate",
|
||||
"feed_forward.down_proj", "ffn_down",
|
||||
"feed_forward.up_proj", "ffn_up",
|
||||
"multi_modal_projector", "mm",
|
||||
"ffn_norm", "ffn_norm",
|
||||
"lm_head", "output",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
var dims []int
|
||||
for _, dim := range shape {
|
||||
dims = append(dims, int(dim))
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.HasSuffix(name, ".attn_q.weight") {
|
||||
heads = p.NumAttentionHeads
|
||||
} else if strings.HasSuffix(name, ".attn_k.weight") {
|
||||
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
|
||||
}
|
||||
|
||||
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(dims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
@@ -37,10 +37,6 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n <= 0 || n > 100<<20 {
|
||||
return nil, fmt.Errorf("invalid safetensors file %q (header size: %d): file may be corrupted or a Git LFS pointer", p, n)
|
||||
}
|
||||
|
||||
b := bytes.NewBuffer(make([]byte, 0, n))
|
||||
if _, err = io.CopyN(b, f, n); err != nil {
|
||||
return nil, err
|
||||
|
||||
10
docs/api.md
@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `system`: system message (overrides what is defined in the `Modelfile`)
|
||||
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
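For the `options` parameter documented in the hunk above, a short sketch using the repository's Go client; the model name and option values are illustrative, and `format` can be set on the same request in the same spirit.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.GenerateRequest{
		Model:  "llama3.2", // illustrative model name
		Prompt: "Why is the sky blue?",
		// Modelfile parameters such as temperature go in options.
		Options: map[string]any{
			"temperature": 0.2,
		},
		Stream: &stream,
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Println(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```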
@@ -507,7 +507,7 @@ The `message` object has the following fields:
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
|
||||
- `template`: (optional) the prompt template for the model
|
||||
- `license`: (optional) a string or list of strings containing the license or licenses for the model
|
||||
- `system`: (optional) a string containing the system prompt for the model
|
||||
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
|
||||
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
|
||||
- `messages`: (optional) a list of message objects used to create a conversation
|
||||
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `quantize` (optional): quantize a non-quantized (e.g. float16) model
|
||||
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
|
||||
Advanced parameters:
|
||||
|
||||
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
- `dimensions`: number of dimensions for the embedding
|
||||
|
||||
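For the embeddings parameters above, a Go sketch mirroring the request shape used by the integration test later in this diff; the model name is illustrative.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.EmbedRequest{
		Model:     "all-minilm", // illustrative embedding model
		Input:     "why is the sky blue?",
		KeepAlive: &api.Duration{Duration: 30 * time.Second},
	}

	resp, err := client.Embed(context.Background(), req)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("embeddings:", len(resp.Embeddings), "dims:", len(resp.Embeddings[0]))
}
```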
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
|
||||
|
||||
Advanced parameters:
|
||||
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
### Examples
|
||||
|
||||
File diff suppressed because one or more lines are too long
46
docs/tools/extract-examples/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# extract-examples
|
||||
|
||||
Extracts code examples from MDX files to a temp directory so you can run them.
|
||||
|
||||
## Usage
|
||||
|
||||
```shell
|
||||
go run docs/tools/extract-examples/main.go <mdx-file>
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
```shell
|
||||
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
|
||||
- 01_basic.py
|
||||
- 01_basic.js
|
||||
- 01_basic.sh
|
||||
- 02_responses.py
|
||||
- 02_responses.js
|
||||
- 02_responses.sh
|
||||
- 03_vision.py
|
||||
- 03_vision.js
|
||||
- 03_vision.sh
|
||||
|
||||
Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
|
||||
To run examples:
|
||||
|
||||
cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
npm install # for JS examples
|
||||
|
||||
then run individual files with `node file.js`, `python file.py`, `bash file.sh`
|
||||
```
|
||||
|
||||
## How it works
|
||||
|
||||
- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
|
||||
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
|
||||
- Writes all extracted files to a temp directory
|
||||
137
docs/tools/extract-examples/main.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
mdxFile := os.Args[1]
|
||||
|
||||
f, err := os.Open(mdxFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create temp directory
|
||||
tempDir, err := os.MkdirTemp("", "mdx-examples-*")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Extracting code examples to: %s\n\n", tempDir)
|
||||
|
||||
// Patterns
|
||||
codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
|
||||
codeGroupStart := regexp.MustCompile("^<CodeGroup")
|
||||
codeGroupEnd := regexp.MustCompile("^</CodeGroup>")
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
inCodeBlock := false
|
||||
inCodeGroup := false
|
||||
var currentFile string
|
||||
var content strings.Builder
|
||||
count := 0
|
||||
codeGroupNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Track CodeGroup boundaries
|
||||
if codeGroupStart.MatchString(line) {
|
||||
inCodeGroup = true
|
||||
codeGroupNum++
|
||||
continue
|
||||
}
|
||||
if codeGroupEnd.MatchString(line) {
|
||||
inCodeGroup = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inCodeBlock {
|
||||
if line == "```" {
|
||||
// End of code block - write file
|
||||
if currentFile != "" {
|
||||
outPath := filepath.Join(tempDir, currentFile)
|
||||
if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
|
||||
} else {
|
||||
fmt.Printf(" - %s\n", currentFile)
|
||||
count++
|
||||
}
|
||||
}
|
||||
inCodeBlock = false
|
||||
currentFile = ""
|
||||
content.Reset()
|
||||
} else {
|
||||
content.WriteString(line)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
|
||||
inCodeBlock = true
|
||||
filename := matches[2]
|
||||
// Prefix with CodeGroup number if inside a CodeGroup
|
||||
if inCodeGroup {
|
||||
currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
|
||||
} else {
|
||||
currentFile = filename
|
||||
}
|
||||
content.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Write package.json for JavaScript dependencies
|
||||
packageJSON := `{
|
||||
"name": "mdx-examples",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"openai": "^4",
|
||||
"ollama": "^0.5"
|
||||
}
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
|
||||
}
|
||||
|
||||
// Write pyproject.toml for Python dependencies
|
||||
pyprojectTOML := `[project]
|
||||
name = "mdx-examples"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"openai",
|
||||
"ollama",
|
||||
]
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("To run examples:\n")
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type GGML struct {
|
||||
@@ -550,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
@@ -791,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention {
|
||||
if useFlashAttention == ml.FlashAttentionEnabled {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
@@ -809,6 +810,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
|
||||
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
|
||||
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
func (f GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
|
||||
|
||||
@@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
|
||||
func TestEmbedLargeInput(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
// Test with progressively larger inputs
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputWords int
|
||||
}{
|
||||
{"medium_input_256_words", 256},
|
||||
{"large_input_512_words", 512},
|
||||
{"very_large_input_800_words", 800},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
words := make([]string, tc.inputWords)
|
||||
for i := range words {
|
||||
words[i] = "word"
|
||||
}
|
||||
input := strings.Join(words, " ")
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: input,
|
||||
KeepAlive: &api.Duration{Duration: 30 * time.Second},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(mctx, client, t, req)
|
||||
if err != nil {
|
||||
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 1 {
|
||||
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if len(res.Embeddings[0]) == 0 {
|
||||
t.Fatal("expected non-empty embedding")
|
||||
}
|
||||
|
||||
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedStatusCode tests that errors from the embedding endpoint
|
||||
// properly preserve their HTTP status codes when returned to the client.
|
||||
// This test specifically checks the error handling path in EmbedHandler
|
||||
|
||||
@@ -118,18 +118,22 @@ type ContextParams struct {
|
||||
c C.struct_llama_context_params
|
||||
}
|
||||
|
||||
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
|
||||
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams {
|
||||
params := C.llama_context_default_params()
|
||||
params.n_ctx = C.uint(numCtx)
|
||||
params.n_batch = C.uint(batchSize)
|
||||
params.n_batch = C.uint(batchSize * numSeqMax)
|
||||
params.n_ubatch = C.uint(batchSize)
|
||||
params.n_seq_max = C.uint(numSeqMax)
|
||||
params.n_threads = C.int(threads)
|
||||
params.n_threads_batch = params.n_threads
|
||||
params.embeddings = C.bool(true)
|
||||
if flashAttention {
|
||||
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
|
||||
} else {
|
||||
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
|
||||
switch flashAttention {
|
||||
case ml.FlashAttentionEnabled:
|
||||
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED)
|
||||
case ml.FlashAttentionDisabled:
|
||||
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED)
|
||||
case ml.FlashAttentionAuto:
|
||||
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO)
|
||||
}
|
||||
params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
|
||||
params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
|
||||
|
||||
@@ -188,6 +188,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
if len(projectors) > 0 && llamaModel != nil {
|
||||
loadRequest.ProjectorPath = projectors[0]
|
||||
}
|
||||
// Determine if the user has forced FA on or off
|
||||
faUserSet := false
|
||||
if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
|
||||
faUserSet = true
|
||||
}
|
||||
|
||||
fa := envconfig.FlashAttention(f.FlashAttention())
|
||||
|
||||
@@ -205,19 +210,51 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if fa {
|
||||
slog.Info("enabling flash attention")
|
||||
loadRequest.FlashAttention = true
|
||||
|
||||
// Flash Attention also supports kv cache quantization
|
||||
// Enable if the requested and kv cache type is supported by the model
|
||||
if f.SupportsKVCacheType(kvct) {
|
||||
loadRequest.KvCacheType = kvct
|
||||
} else {
|
||||
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||
if textProcessor == nil {
|
||||
flashAttention := ml.FlashAttentionAuto
|
||||
if faUserSet {
|
||||
if fa {
|
||||
flashAttention = ml.FlashAttentionEnabled
|
||||
} else {
|
||||
flashAttention = ml.FlashAttentionDisabled
|
||||
}
|
||||
}
|
||||
|
||||
if kvct != "" {
|
||||
if f.KVCacheTypeIsQuantized(kvct) {
|
||||
if flashAttention != ml.FlashAttentionEnabled {
|
||||
slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
|
||||
loadRequest.KvCacheType = ""
|
||||
} else if f.SupportsKVCacheType(kvct) {
|
||||
loadRequest.KvCacheType = kvct
|
||||
} else {
|
||||
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
|
||||
}
|
||||
} else {
|
||||
if f.SupportsKVCacheType(kvct) {
|
||||
loadRequest.KvCacheType = kvct
|
||||
} else {
|
||||
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRequest.FlashAttention = flashAttention
|
||||
} else {
|
||||
// For Ollama engine, use our SupportsFlashAttention logic
|
||||
if fa {
|
||||
slog.Info("enabling flash attention")
|
||||
loadRequest.FlashAttention = ml.FlashAttentionEnabled
|
||||
|
||||
// Flash Attention also supports kv cache quantization
|
||||
// Enable if the requested and kv cache type is supported by the model
|
||||
if f.SupportsKVCacheType(kvct) {
|
||||
loadRequest.KvCacheType = kvct
|
||||
} else {
|
||||
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||
}
|
||||
} else if kvct != "" && kvct != "f16" {
|
||||
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
|
||||
}
|
||||
} else if kvct != "" && kvct != "f16" {
|
||||
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
|
||||
}
|
||||
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
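The `faUserSet` check near the top of this hunk presumably relies on `envconfig.FlashAttention` taking a default value: if calling it with `true` and with `false` returns the same answer, the environment variable must have been set explicitly, because the default never mattered. A stand-alone illustration of that trick; the helper below is a hypothetical stand-in, not the envconfig API.

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// boolWithDefault reads an environment variable as a bool, falling back to
// def when it is unset or unparsable (the same shape this diff gives
// envconfig.FlashAttention, as far as the call sites show).
func boolWithDefault(key string, def bool) bool {
	if v, err := strconv.ParseBool(os.Getenv(key)); err == nil {
		return v
	}
	return def
}

func main() {
	// If both calls agree, the default was never used, so the variable was set.
	userSet := boolWithDefault("OLLAMA_FLASH_ATTENTION", true) ==
		boolWithDefault("OLLAMA_FLASH_ATTENTION", false)
	fmt.Println("explicitly set:", userSet)
}
```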
@@ -435,7 +472,7 @@ type LoadRequest struct {
|
||||
LoraPath []string
|
||||
Parallel int
|
||||
BatchSize int
|
||||
FlashAttention bool
|
||||
FlashAttention ml.FlashAttentionType
|
||||
KvSize int
|
||||
KvCacheType string
|
||||
NumThreads int
|
||||
@@ -474,6 +511,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
||||
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
|
||||
}
|
||||
|
||||
// Check if embedding model and adjust batch size accordingly
|
||||
_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
|
||||
if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
|
||||
s.loadRequest.BatchSize = s.options.NumCtx
|
||||
slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize)
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
|
||||
s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)
|
||||
|
||||
|
||||
@@ -433,3 +433,111 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type ResponsesWriter struct {
|
||||
BaseWriter
|
||||
converter *openai.ResponsesStreamConverter
|
||||
model string
|
||||
stream bool
|
||||
responseID string
|
||||
itemID string
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// Non-streaming response
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
chatReq, err := openai.FromResponsesRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if client requested streaming (defaults to false)
|
||||
streamRequested := req.Stream != nil && *req.Stream
|
||||
|
||||
// Pass streaming preference to the underlying chat request
|
||||
chatReq.Stream = &streamRequested
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
responseID := fmt.Sprintf("resp_%d", rand.Intn(999999))
|
||||
itemID := fmt.Sprintf("msg_%d", rand.Intn(999999))
|
||||
|
||||
w := &ResponsesWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
|
||||
model: req.Model,
|
||||
stream: streamRequested,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
}
|
||||
|
||||
// Set headers based on streaming mode
|
||||
if streamRequested {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ type BackendParams struct {
|
||||
GPULayers GPULayersList
|
||||
|
||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||
FlashAttention bool
|
||||
FlashAttention FlashAttentionType
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(string, BackendParams) (Backend, error))
|
||||
|
||||
@@ -109,7 +109,7 @@ type Backend struct {
|
||||
// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
|
||||
btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory
|
||||
|
||||
flashAttention bool
|
||||
flashAttention ml.FlashAttentionType
|
||||
|
||||
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
|
||||
maxGraphNodes int
|
||||
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
|
||||
}
|
||||
|
||||
func (b *Backend) CacheConfig() ml.CacheConfig {
|
||||
if b.flashAttention {
|
||||
if b.flashAttention == ml.FlashAttentionEnabled {
|
||||
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
|
||||
} else {
|
||||
return ml.CacheConfig{CachePadding: 256, PermutedV: true}
|
||||
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
|
||||
query := t.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
if t.b.flashAttention {
|
||||
if t.b.flashAttention == ml.FlashAttentionEnabled {
|
||||
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
|
||||
|
||||
26
ml/device.go
@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type FlashAttentionType int32
|
||||
|
||||
const (
|
||||
// Aligned with llama_flash_attn_type
|
||||
FlashAttentionAuto FlashAttentionType = -1
|
||||
FlashAttentionDisabled FlashAttentionType = 0
|
||||
FlashAttentionEnabled FlashAttentionType = 1
|
||||
)
|
||||
|
||||
func (f FlashAttentionType) LogValue() slog.Value {
|
||||
return slog.AnyValue(f.String())
|
||||
}
|
||||
|
||||
func (f FlashAttentionType) String() string {
|
||||
switch f {
|
||||
case FlashAttentionAuto:
|
||||
return "Auto"
|
||||
case FlashAttentionDisabled:
|
||||
return "Disabled"
|
||||
case FlashAttentionEnabled:
|
||||
return "Enabled"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variables
|
||||
// Set mustFilter true to enable filtering of CUDA devices
|
||||
|
||||
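Because `FlashAttentionType` implements `LogValue`, the enum renders as its string form in structured logs rather than as a raw integer. A small, illustrative use of the new type:

```go
package main

import (
	"log/slog"

	"github.com/ollama/ollama/ml"
)

func main() {
	fa := ml.FlashAttentionAuto
	// LogValue() means this logs as flash_attention=Auto rather than -1.
	slog.Info("loading model", "flash_attention", fa)

	if fa == ml.FlashAttentionEnabled {
		slog.Info("flash attention explicitly enabled")
	}
}
```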
@@ -2,7 +2,6 @@ package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
@@ -13,25 +12,26 @@ import (
|
||||
)
|
||||
|
||||
type TextConfig struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
attnKeyLen, attnValLen int
|
||||
eps, ropeScale float32
|
||||
ropeLocalBase float32
|
||||
largeModelScaling bool
|
||||
slidingWindowPattern []bool
|
||||
ropeBase float32
|
||||
ropeType string
|
||||
ropeOriginalContext int
|
||||
ropeExtrapolation float32
|
||||
ropeBetaFast float32
|
||||
ropeBetaSlow float32
|
||||
finalLogitSoftcap float32
|
||||
hiddenSize, contextLength, numHeads, numKVHeads int
|
||||
attnKeyLen, attnValLen int
|
||||
eps, ropeScale float32
|
||||
ropeLocalBase float32
|
||||
largeModelScaling bool
|
||||
slidingWindow uint32
|
||||
slidingWindowPattern []bool
|
||||
ropeBase float32
|
||||
ropeType string
|
||||
ropeOriginalContext int
|
||||
ropeExtrapolation float32
|
||||
ropeBetaFast float32
|
||||
ropeBetaSlow float32
|
||||
finalLogitSoftcap float32
|
||||
}
|
||||
|
||||
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
|
||||
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
|
||||
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
|
||||
ropeOpts = append(ropeOpts,
|
||||
rope.WithOriginalContextLength(o.ropeOriginalContext),
|
||||
rope.WithExtrapolationFactor(o.ropeExtrapolation),
|
||||
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
|
||||
)
|
||||
}
|
||||
|
||||
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
|
||||
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
@@ -55,6 +55,9 @@ type TextModel struct {
|
||||
|
||||
const (
|
||||
gemmaGlobalCacheCount = 6
|
||||
gemma1BLayerCount = 26
|
||||
gemma4BLayerCount = 34
|
||||
gemma12BLayerCount = 48
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
@@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
Layers: make([]TextLayer, numBlocks),
|
||||
TextConfig: &TextConfig{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
contextLength: int(c.Uint("context_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||
@@ -77,6 +81,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeBase: c.Float("rope.freq_base", 1000000.0),
|
||||
slidingWindow: c.Uint("attention.sliding_window"),
|
||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
|
||||
@@ -88,14 +93,20 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
},
|
||||
}
|
||||
|
||||
// Google's Gemma 3 release with sliding window attention does
|
||||
// not use final logit softcapping, and so force it to 0.0
|
||||
// TODO (jmorganca): this should ideally be set to 0.0 in the
|
||||
// model configuration instead of here, as future versions of
|
||||
// models may include both sliding window attention and final
|
||||
// logit softcapping.
|
||||
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
|
||||
m.TextConfig.finalLogitSoftcap = 0.0
|
||||
// Apply corrections for older versions of the Gemma 3 models
|
||||
// by looking at whether they use sliding window attention and
|
||||
// based on their layer counts.
|
||||
if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
|
||||
switch numBlocks {
|
||||
case gemma1BLayerCount:
|
||||
// The 1B model has final logit softcapping set to 30.0
|
||||
// but it should be 0.0
|
||||
m.TextConfig.finalLogitSoftcap = 0.0
|
||||
case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
|
||||
// The 4B, 12B, and 27B models have rope scale unset
|
||||
// but it should be set to 8.0
|
||||
m.TextConfig.ropeScale = 8.0
|
||||
}
|
||||
}
|
||||
|
||||
if numBlocks == gemma27BLayerCount {
|
||||
@@ -114,31 +125,31 @@ type TextSelfAttention struct {
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
|
||||
func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
|
||||
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
|
||||
return opts.ropeLocalBase
|
||||
return opts.ropeLocalBase, 1.0
|
||||
}
|
||||
|
||||
// Standard Gemma3: only every n-th layer is global,
|
||||
// where n = gemmaGlobalCacheCount, otherwise use
|
||||
// the local rope base
|
||||
if (layer+1)%gemmaGlobalCacheCount > 0 {
|
||||
return opts.ropeLocalBase
|
||||
return opts.ropeLocalBase, 1.0
|
||||
}
|
||||
|
||||
// default to global rope base
|
||||
return opts.ropeBase
|
||||
return opts.ropeBase, opts.ropeScale
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
|
||||
ropeBase := opts.ropeBaseForLayer(layer)
|
||||
ropeBase, ropeScale := opts.ropeValuesForLayer(layer)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
|
||||
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -149,7 +160,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
|
||||
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -162,7 +173,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
|
||||
ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
@@ -17,10 +18,30 @@ type TextOptions struct {
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeOrigPosEmbeddings int
|
||||
ropeScalingBeta float32
|
||||
ropeType string
|
||||
ropeExtrapolation float32
|
||||
ropeBetaFast float32
|
||||
ropeBetaSlow float32
|
||||
ropeMscale float32
|
||||
ropeMscaleAllDim float32
|
||||
}
|
||||
|
||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale)
|
||||
var ropeOpts []func(*rope.Options)
|
||||
if o.ropeType == "yarn" {
|
||||
if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
|
||||
ropeOpts = append(ropeOpts, rope.WithAttentionFactor(1.0/float32(0.1*math.Log(float64(o.ropeScale))+1.0)))
|
||||
}
|
||||
|
||||
ropeOpts = append(ropeOpts,
|
||||
rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
|
||||
rope.WithExtrapolationFactor(o.ropeExtrapolation),
|
||||
rope.WithBetaFast(o.ropeBetaFast),
|
||||
rope.WithBetaSlow(o.ropeBetaSlow),
|
||||
)
|
||||
}
|
||||
|
||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
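This deepseek2 change and the earlier gemma3 change compute the YaRN attention factor from the rope scaling factor with the same expression; pulled out on its own it is simply the following, shown here only as a restatement of the code.

```go
package main

import (
	"fmt"
	"math"
)

// yarnAttnFactor matches the expression used in both model files:
// 1.0 / (0.1*ln(scale) + 1.0).
func yarnAttnFactor(scale float64) float32 {
	return float32(1.0 / (0.1*math.Log(scale) + 1.0))
}

func main() {
	fmt.Println(yarnAttnFactor(8.0)) // ~0.83 for a rope scale of 8
}
```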
@@ -150,9 +171,15 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
|
||||
ropeScalingBeta: c.Float("rope.scaling_beta"),
|
||||
ropeScalingBeta: c.Float("rope.scaling_beta", 0.1),
|
||||
ropeBetaFast: c.Float("rope.scaling.beta_fast", 32.0),
|
||||
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeMscale: c.Float("rope.scaling.mscale"),
|
||||
ropeMscaleAllDim: c.Float("rope.scaling.mscale_all_dim"),
|
||||
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
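For reference, the argument passed to rope.WithAttentionFactor above is the reciprocal of the standard YaRN mscale term, 0.1*ln(ropeScale) + 1.0. A minimal sketch of that arithmetic, separate from the diff (the helper name is illustrative, not part of the change):

package main

import (
	"fmt"
	"math"
)

// yarnAttentionFactor reproduces 1.0 / (0.1*ln(scale) + 1.0) as used in the diff above.
func yarnAttentionFactor(ropeScale float64) float32 {
	mscale := 0.1*math.Log(ropeScale) + 1.0 // YaRN mscale for a given context-extension factor
	return 1.0 / float32(mscale)
}

func main() {
	fmt.Println(yarnAttentionFactor(4.0)) // a 4x scaling factor gives roughly 0.878
}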
292	model/parsers/deepseek.go	Normal file
@@ -0,0 +1,292 @@
package parsers

import (
	"encoding/json"
	"errors"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
)

type DeepSeekParserState int

const (
	DeepSeekCollectingThinking DeepSeekParserState = iota
	DeepSeekCollectingContent
	DeepSeekCollectingToolCalls
	DeepSeekCollectingToolOutput
)

const (
	deepseekThinkingCloseTag   = "</think>"
	deepseekToolCallsBeginTag  = "<|tool▁calls▁begin|>"
	deepseekToolCallsEndTag    = "<|tool▁calls▁end|>"
	deepseekToolCallBeginTag   = "<|tool▁call▁begin|>"
	deepseekToolCallEndTag     = "<|tool▁call▁end|>"
	deepseekToolSepTag         = "<|tool▁sep|>"
	deepseekToolOutputBeginTag = "<|tool▁output▁begin|>"
	deepseekToolOutputEndTag   = "<|tool▁output▁end|>"
)

type DeepSeekParser struct {
	state              DeepSeekParserState
	buffer             strings.Builder
	hasThinkingSupport bool
}

func (p *DeepSeekParser) HasToolSupport() bool {
	return true
}

func (p *DeepSeekParser) HasThinkingSupport() bool {
	return p.hasThinkingSupport
}

func (p *DeepSeekParser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
	prefill := lastMessage != nil && lastMessage.Role == "assistant"

	// Check both model capability AND request preference
	thinkingEnabled := p.HasThinkingSupport() && (thinkValue == nil || thinkValue.Bool())

	if !thinkingEnabled {
		p.state = DeepSeekCollectingContent
		return
	}

	if prefill && lastMessage.Content != "" {
		p.state = DeepSeekCollectingContent
		return
	}

	p.state = DeepSeekCollectingThinking
}

func (p *DeepSeekParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.setInitialState(lastMessage, tools, thinkValue)
	return tools
}

type deepseekEvent interface {
	isDeepSeekEvent()
}

type deepseekEventThinkingContent struct {
	content string
}

type deepseekEventContent struct {
	content string
}

type deepseekEventToolCall struct {
	toolCall api.ToolCall
}

func (deepseekEventThinkingContent) isDeepSeekEvent() {}

func (deepseekEventContent) isDeepSeekEvent() {}

func (deepseekEventToolCall) isDeepSeekEvent() {}

func (p *DeepSeekParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder
	for _, event := range events {
		switch event := event.(type) {
		case deepseekEventToolCall:
			toolCalls = append(toolCalls, event.toolCall)
		case deepseekEventThinkingContent:
			thinkingSb.WriteString(event.content)
		case deepseekEventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *DeepSeekParser) parseEvents() []deepseekEvent {
	var all []deepseekEvent

	keepLooping := true
	for keepLooping {
		var events []deepseekEvent
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	return all
}

func (p *DeepSeekParser) eat() ([]deepseekEvent, bool) {
	var events []deepseekEvent
	bufStr := p.buffer.String()
	if bufStr == "" {
		return events, false
	}

	switch p.state {
	case DeepSeekCollectingThinking:
		if strings.Contains(bufStr, deepseekThinkingCloseTag) { // thinking[</think>] -> content
			split := strings.SplitN(bufStr, deepseekThinkingCloseTag, 2)
			thinking := split[0]
			thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)

			remaining := split[1]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = DeepSeekCollectingContent

			if len(thinking) > 0 {
				events = append(events, deepseekEventThinkingContent{content: thinking})
			}
			return events, true
		} else if overlapLen := overlap(bufStr, deepseekThinkingCloseTag); overlapLen > 0 { // partial </think>
			beforePartialTag := bufStr[:len(bufStr)-overlapLen]
			trailingLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, deepseekEventThinkingContent{content: unambiguous})
			}
			return events, false
		} else { // otherwise it's thinking content
			whitespaceLen := trailingWhitespaceLen(bufStr)
			ambiguousStart := len(bufStr) - whitespaceLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, deepseekEventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case DeepSeekCollectingContent:
		switch {
		case strings.Contains(bufStr, deepseekToolCallsBeginTag): // content[<|tool▁calls▁begin|>] -> tool calls
			split := strings.SplitN(bufStr, deepseekToolCallsBeginTag, 2)
			contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
			remaining := split[1]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = DeepSeekCollectingToolCalls

			if len(contentBefore) > 0 {
				events = append(events, deepseekEventContent{content: contentBefore})
			}
			return events, true
		case strings.Contains(bufStr, deepseekToolOutputBeginTag): // content[<|tool▁output▁begin|>] -> tool output
			split := strings.SplitN(bufStr, deepseekToolOutputBeginTag, 2)
			contentBefore := split[0] // Don't trim whitespace - preserve spaces
			remaining := split[1]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = DeepSeekCollectingToolOutput

			if len(contentBefore) > 0 {
				events = append(events, deepseekEventContent{content: contentBefore})
			}
			return events, true
		default: // otherwise it's content
			p.buffer.Reset()
			if len(bufStr) > 0 {
				events = append(events, deepseekEventContent{content: bufStr})
			}
			return events, false
		}

	case DeepSeekCollectingToolCalls:
		if idx := strings.Index(bufStr, deepseekToolCallBeginTag); idx != -1 {
			startIdx := idx + len(deepseekToolCallBeginTag)
			if endIdx := strings.Index(bufStr[startIdx:], deepseekToolCallEndTag); endIdx != -1 {
				toolCallContent := bufStr[startIdx : startIdx+endIdx]

				if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
					remaining := bufStr[startIdx+endIdx+len(deepseekToolCallEndTag):]
					remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

					p.buffer.Reset()
					p.buffer.WriteString(remaining)

					events = append(events, deepseekEventToolCall{toolCall: toolCall})
					return events, true
				} else {
					slog.Warn("deepseek tool call parsing failed", "error", err)
				}
			}
		}

		if idx := strings.Index(bufStr, deepseekToolCallsEndTag); idx != -1 {
			remaining := bufStr[idx+len(deepseekToolCallsEndTag):]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = DeepSeekCollectingContent

			return events, true
		}

		return events, false

	case DeepSeekCollectingToolOutput:
		if idx := strings.Index(bufStr, deepseekToolOutputEndTag); idx != -1 {
			toolOutputContent := bufStr[:idx]
			remaining := bufStr[idx+len(deepseekToolOutputEndTag):]
			// Don't trim whitespace - preserve spaces after tool output tags

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = DeepSeekCollectingContent

			if len(toolOutputContent) > 0 {
				events = append(events, deepseekEventContent{content: toolOutputContent})
			}
			return events, true
		}

		return events, false
	}

	return events, false
}

func (p *DeepSeekParser) parseToolCallContent(content string) (api.ToolCall, error) {
	// Expected format: tool_name<|tool▁sep|>{args}
	parts := strings.SplitN(content, deepseekToolSepTag, 2)
	if len(parts) < 2 {
		return api.ToolCall{}, errors.New("invalid format")
	}

	toolName := strings.TrimSpace(parts[0])
	argsJSON := strings.TrimSpace(parts[1])

	var args api.ToolCallFunctionArguments
	if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
		return api.ToolCall{}, err
	}

	return api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      toolName,
			Arguments: args,
		},
	}, nil
}
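A minimal sketch of how this parser is driven, assuming it runs inside the model/parsers package so DeepSeekParser and the api types are in scope; the chunk contents are invented for illustration:

package parsers

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

// sketchDeepSeekStreaming feeds streamed chunks through the parser and
// accumulates thinking and content, mirroring how Add is called per chunk.
func sketchDeepSeekStreaming() {
	parser := &DeepSeekParser{hasThinkingSupport: true}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	chunks := []string{"Let me think", "...</think>", "The answer is 42."}
	var content, thinking string
	for i, chunk := range chunks {
		c, t, _, err := parser.Add(chunk, i == len(chunks)-1)
		if err != nil {
			panic(err)
		}
		content += c
		thinking += t
	}
	fmt.Println(thinking) // Let me think...
	fmt.Println(content)  // The answer is 42.
}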
721	model/parsers/deepseek_test.go	Normal file
@@ -0,0 +1,721 @@
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestDeepSeekParser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
hasThinking bool
|
||||
}{
|
||||
{
|
||||
name: "simple_content",
|
||||
input: "Hello, how are you?",
|
||||
expectedContent: "Hello, how are you?",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "thinking_content",
|
||||
input: "I need to think about this...</think>The answer is 42.",
|
||||
expectedThinking: "I need to think about this...",
|
||||
expectedContent: "The answer is 42.",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "no_thinking_simple",
|
||||
input: "Just a regular response.",
|
||||
expectedContent: "Just a regular response.",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "thinking_with_newlines",
|
||||
input: "Let me think:\n- Point 1\n- Point 2</think>\n\nHere's my answer.",
|
||||
expectedThinking: "Let me think:\n- Point 1\n- Point 2",
|
||||
expectedContent: "Here's my answer.",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "tool_call_simple",
|
||||
input: "I'll check the weather.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "I'll check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "multiple_tool_calls",
|
||||
input: "Getting weather for both cities.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"London\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "Getting weather for both cities.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "London",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "tool_output",
|
||||
input: "Here's the weather: <|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|> Hope that helps!",
|
||||
expectedContent: "Here's the weather: Temperature: 22°C, Sunny Hope that helps!",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "complex_tool_arguments",
|
||||
input: "Processing data.<|tool▁calls▁begin|><|tool▁call▁begin|>process_data<|tool▁sep|>{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "Processing data.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []interface{}{"item1", "item2"},
|
||||
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "thinking_with_tool_call", // technically this can't happen, but the parser can handle it
|
||||
input: "Let me check the weather...</think>I'll get that for you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedThinking: "Let me check the weather...",
|
||||
expectedContent: "I'll get that for you.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "empty_content",
|
||||
input: "",
|
||||
expectedContent: "",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "only_thinking",
|
||||
input: "Just thinking content</think>",
|
||||
expectedThinking: "Just thinking content",
|
||||
expectedContent: "",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "multiple_tool_outputs",
|
||||
input: "Results: <|tool▁output▁begin|>Paris: 22°C<|tool▁output▁end|> and <|tool▁output▁begin|>London: 18°C<|tool▁output▁end|>",
|
||||
expectedContent: "Results: Paris: 22°C and London: 18°C",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "unicode_content",
|
||||
input: "مرحبا بالعالم! 你好世界! 🌍",
|
||||
expectedContent: "مرحبا بالعالم! 你好世界! 🌍",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "emoji_passthrough",
|
||||
input: "Task completed ✅ 🎉",
|
||||
expectedContent: "Task completed ✅ 🎉",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "emoji_after_tool_call",
|
||||
input: "I'll help you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>完成 ✅",
|
||||
expectedContent: "I'll help you.完成 ✅",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "newlines_and_whitespace",
|
||||
input: "Line 1\n\nLine 3\t\tTabbed content",
|
||||
expectedContent: "Line 1\n\nLine 3\t\tTabbed content",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "thinking_with_unicode",
|
||||
input: "我在思考这个问题...</think>答案是42。",
|
||||
expectedThinking: "我在思考这个问题...",
|
||||
expectedContent: "答案是42。",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_unicode_args",
|
||||
input: "Searching for information.<|tool▁calls▁begin|><|tool▁call▁begin|>search<|tool▁sep|>{\"query\":\"北京天气\",\"language\":\"中文\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "Searching for information.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"query": "北京天气",
|
||||
"language": "中文",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "tool_output_with_unicode",
|
||||
input: "天气信息: <|tool▁output▁begin|>北京: 25°C, 晴天<|tool▁output▁end|> 希望对您有帮助!",
|
||||
expectedContent: "天气信息: 北京: 25°C, 晴天 希望对您有帮助!",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "mixed_content_with_special_chars",
|
||||
input: "Price: $100 & tax @ 10% = $110 <|tool▁output▁begin|>Total: $110<|tool▁output▁end|> (final)",
|
||||
expectedContent: "Price: $100 & tax @ 10% = $110 Total: $110 (final)",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_special_chars",
|
||||
input: "Processing data.<|tool▁calls▁begin|><|tool▁call▁begin|>execute_command<|tool▁sep|>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "Processing data.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "execute_command",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"command": "ls && echo \"done\"",
|
||||
"path": "/home/user",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "thinking_with_special_chars",
|
||||
input: "Let me calculate: 2+2=4 & 3*3=9...</think>The results are correct!",
|
||||
expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...",
|
||||
expectedContent: "The results are correct!",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "empty_tool_call_args",
|
||||
input: "Pinging server.<|tool▁calls▁begin|><|tool▁call▁begin|>ping<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "Pinging server.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ping",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "empty_tool_output",
|
||||
input: "Checking status: <|tool▁output▁begin|><|tool▁output▁end|> No output received.",
|
||||
expectedContent: "Checking status: No output received.",
|
||||
hasThinking: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
|
||||
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
|
||||
|
||||
content, thinking, calls, err := parser.Add(tt.input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
|
||||
t.Errorf("Content mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
|
||||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
|
||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
hasThinking bool
|
||||
}{
|
||||
{
|
||||
name: "streaming_simple_content",
|
||||
chunks: []string{"Hello, ", "how are ", "you?"},
|
||||
expectedContent: "Hello, how are you?",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "streaming_thinking",
|
||||
chunks: []string{"I need to ", "think about this", "...</think>", "The answer is 42."},
|
||||
expectedThinking: "I need to think about this...",
|
||||
expectedContent: "The answer is 42.",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "streaming_tool_call",
|
||||
chunks: []string{"I'll check weather.", "<|tool▁calls▁begin|>", "<|tool▁call▁begin|>get_weather", "<|tool▁sep|>{\"location\":\"Paris\"}", "<|tool▁call▁end|><|tool▁calls▁end|>"},
|
||||
expectedContent: "I'll check weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "streaming_thinking_with_partial_tag",
|
||||
chunks: []string{"Thinking about this", "...</", "think>", "Done thinking."},
|
||||
expectedThinking: "Thinking about this...",
|
||||
expectedContent: "Done thinking.",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "streaming_tool_output",
|
||||
chunks: []string{"Weather info: ", "<|tool▁output▁begin|>", "25°C, Sunny", "<|tool▁output▁end|>", " Enjoy!"},
|
||||
expectedContent: "Weather info: 25°C, Sunny Enjoy!",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "streaming_with_split_tags",
|
||||
chunks: []string{"Content before ", "<|tool▁calls▁begin|><|tool▁call▁begin|>test", "<|tool▁sep|>{}", "<|tool▁call▁end|><|tool▁calls▁end|>", " after"},
|
||||
expectedContent: "Content before after",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "streaming_thinking_with_split_end_tag",
|
||||
chunks: []string{"Thinking content", "</th", "ink>", "Regular content"},
|
||||
expectedThinking: "Thinking content",
|
||||
expectedContent: "Regular content",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "streaming_unicode_content",
|
||||
chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"},
|
||||
expectedContent: "مرحبا بالعالم! 你好世界!",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "streaming_multiple_tool_outputs",
|
||||
chunks: []string{"Results: ", "<|tool▁output▁begin|>", "Paris: 22°C", "<|tool▁output▁end|>", " and ", "<|tool▁output▁begin|>", "London: 18°C", "<|tool▁output▁end|>"},
|
||||
expectedContent: "Results: Paris: 22°C and London: 18°C",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "streaming_tool_call_with_split_json",
|
||||
chunks: []string{"Processing.", "<|tool▁calls▁begin|><|tool▁call▁begin|>calc<|tool▁sep|>{\"x\":", "42,\"y\":", "24}<|tool▁call▁end|><|tool▁calls▁end|>"},
|
||||
expectedContent: "Processing.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calc",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"x": float64(42),
|
||||
"y": float64(24),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
hasThinking: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
|
||||
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
|
||||
|
||||
var allContent, allThinking string
|
||||
var allCalls []api.ToolCall
|
||||
|
||||
for i, chunk := range tt.chunks {
|
||||
done := i == len(tt.chunks)-1
|
||||
content, thinking, calls, err := parser.Add(chunk, done)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
allContent += content
|
||||
allThinking += thinking
|
||||
allCalls = append(allCalls, calls...)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" {
|
||||
t.Errorf("Content mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" {
|
||||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
|
||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekParser_HasThinkingSupport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasThinking bool
|
||||
expectedSupport bool
|
||||
}{
|
||||
{
|
||||
name: "thinking_enabled",
|
||||
hasThinking: true,
|
||||
expectedSupport: true,
|
||||
},
|
||||
{
|
||||
name: "thinking_disabled",
|
||||
hasThinking: false,
|
||||
expectedSupport: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
|
||||
if got := parser.HasThinkingSupport(); got != tt.expectedSupport {
|
||||
t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekParser_HasToolSupport(t *testing.T) {
|
||||
parser := &DeepSeekParser{}
|
||||
if !parser.HasToolSupport() {
|
||||
t.Error("HasToolSupport() should return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekParser_Init(t *testing.T) {
|
||||
parser := &DeepSeekParser{hasThinkingSupport: true}
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "test_tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
if diff := cmp.Diff(tools, returnedTools); diff != "" {
|
||||
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// Test initial state is set to thinking when enabled
|
||||
if parser.state != DeepSeekCollectingThinking {
|
||||
t.Errorf("Expected initial state to be DeepSeekCollectingThinking, got %v", parser.state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekParser_parseToolCallContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected api.ToolCall
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid_tool_call",
|
||||
content: "get_weather<|tool▁sep|>{\"location\":\"Paris\"}",
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Paris",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex_arguments",
|
||||
content: "process_data<|tool▁sep|>{\"items\":[\"a\",\"b\"],\"config\":{\"enabled\":true}}",
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []interface{}{"a", "b"},
|
||||
"config": map[string]interface{}{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty_arguments",
|
||||
content: "ping<|tool▁sep|>{}",
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ping",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unicode_in_tool_name",
|
||||
content: "获取天气<|tool▁sep|>{\"城市\":\"北京\"}",
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"城市": "北京",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "special_chars_in_arguments",
|
||||
content: "execute<|tool▁sep|>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}",
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "execute",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"command": "ls && echo \"done\"",
|
||||
"path": "/home/user",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "numeric_arguments",
|
||||
content: "calculate<|tool▁sep|>{\"x\":3.14,\"y\":42,\"enabled\":true}",
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid_format_no_separator",
|
||||
content: "get_weather{\"location\":\"Paris\"}",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid_json",
|
||||
content: "get_weather<|tool▁sep|>{invalid json}",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty_tool_name",
|
||||
content: "<|tool▁sep|>{\"arg\":\"value\"}",
|
||||
expectError: false, // This should work, just empty name
|
||||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"arg": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_json_part",
|
||||
content: "tool_name<|tool▁sep|>",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
parser := &DeepSeekParser{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parser.parseToolCallContent(tt.content)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekParser_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
hasThinking bool
|
||||
}{
|
||||
{
|
||||
name: "nested_think_tags_in_thinking",
|
||||
input: "Outer thinking <think>inner</think> content</think>Final content",
|
||||
expectedThinking: "Outer thinking <think>inner",
|
||||
expectedContent: "content</think>Final content",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "multiple_think_close_tags",
|
||||
input: "First thought</think>Second thought</think>Final content",
|
||||
expectedThinking: "First thought",
|
||||
expectedContent: "Second thought</think>Final content",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "empty_thinking_content",
|
||||
input: "</think>Just content",
|
||||
expectedThinking: "",
|
||||
expectedContent: "Just content",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "thinking_disabled_with_think_tags",
|
||||
input: "Some content</think>More content",
|
||||
expectedContent: "Some content</think>More content",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "malformed_tool_call_missing_sep",
|
||||
input: "Testing.<|tool▁calls▁begin|><|tool▁call▁begin|>bad_tool{\"arg\":\"value\"}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "Testing.",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "malformed_tool_call_invalid_json",
|
||||
input: "Testing.<|tool▁calls▁begin|><|tool▁call▁begin|>bad_tool<|tool▁sep|>{invalid json}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "Testing.",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "partial_tool_tag_at_end",
|
||||
input: "Content with partial <|tool▁calls▁",
|
||||
expectedContent: "Content with partial <|tool▁calls▁",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "partial_think_tag_at_end",
|
||||
input: "Thinking content</th",
|
||||
expectedContent: "Thinking content</th",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "partial_think_tag_at_end_with_thinking",
|
||||
input: "Thinking content</th",
|
||||
expectedThinking: "Thinking content",
|
||||
expectedContent: "",
|
||||
hasThinking: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace_only_content",
|
||||
input: " \n\t ",
|
||||
expectedContent: " \n\t ",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "tool_output_with_newlines",
|
||||
input: "Output:\n<|tool▁output▁begin|>Line 1\nLine 2\nLine 3<|tool▁output▁end|>\nDone.",
|
||||
expectedContent: "Output:\nLine 1\nLine 2\nLine 3\nDone.",
|
||||
hasThinking: false,
|
||||
},
|
||||
{
|
||||
name: "consecutive_tool_calls",
|
||||
input: "First.<|tool▁calls▁begin|><|tool▁call▁begin|>tool1<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>Second.<|tool▁calls▁begin|><|tool▁call▁begin|>tool2<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>",
|
||||
expectedContent: "First.",
|
||||
hasThinking: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
|
||||
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
|
||||
|
||||
content, thinking, _, err := parser.Add(tt.input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
|
||||
t.Errorf("Content mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
|
||||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -58,6 +58,8 @@ func ParserForName(name string) Parser {
		return harmony.NewHarmonyMessageHandler()
	case "cogito":
		return &CogitoParser{}
	case "deepseek":
		return &DeepSeekParser{hasThinkingSupport: true}
	case "olmo3":
		return &Olmo3Parser{}
	case "olmo3-think":
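As registered here, the "deepseek" parser always advertises thinking support; whether thinking is actually collected still depends on the request's think value, per setInitialState above. A tiny sketch (the variable name is illustrative):

	p := ParserForName("deepseek")
	_ = p.HasThinkingSupport() // true; a request with thinking disabled still starts in content mode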
@@ -10,12 +10,15 @@ import (
)

const (
	olmo3DefaultSystemMessage  = "You are a helpful function-calling AI assistant. "
	olmo3NoFunctionsMessage    = "You do not currently have access to any functions. "
	olmo3WithFunctionsMessage  = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
	olmo3DefaultSystemMessage  = "You are a helpful function-calling AI assistant. "
	olmo31DefaultSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai. "
	olmo3NoFunctionsMessage    = "You do not currently have access to any functions. "
	olmo3WithFunctionsMessage  = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
)

type Olmo3Renderer struct{}
type Olmo3Renderer struct {
	UseExtendedSystemMessage bool
}

func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
	var sb strings.Builder
@@ -51,7 +54,11 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
	} else {
		// Default system message - single newline after "system"
		sb.WriteString("<|im_start|>system\n")
		sb.WriteString(olmo3DefaultSystemMessage)
		if r.UseExtendedSystemMessage {
			sb.WriteString(olmo31DefaultSystemMessage)
		} else {
			sb.WriteString(olmo3DefaultSystemMessage)
		}

		if len(tools) > 0 {
			functionsJSON, err := marshalWithSpaces(tools)
@@ -140,7 +147,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
	}

	if needsGenerationPrompt {
		sb.WriteString("<|im_start|>assistant\n\n")
		sb.WriteString("<|im_start|>assistant\n")
	}

	return sb.String(), nil
@@ -24,7 +24,7 @@ func TestOlmo3Renderer(t *testing.T) {
				"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
				"<|im_start|>user\n" +
				"Hello!<|im_end|>\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "with system message no tools",
@@ -36,7 +36,7 @@ func TestOlmo3Renderer(t *testing.T) {
				"You are a helpful assistant.<|im_end|>\n" +
				"<|im_start|>user\n" +
				"Hello!<|im_end|>\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "with system message and tools",
@@ -64,7 +64,7 @@ func TestOlmo3Renderer(t *testing.T) {
				`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
				"<|im_start|>user\n" +
				"What is the weather?<|im_end|>\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "default system with tools - includes function instruction",
@@ -93,7 +93,7 @@ func TestOlmo3Renderer(t *testing.T) {
				`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
				"<|im_start|>user\n" +
				"What is the weather?<|im_end|>\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "assistant with tool calls - function call syntax",
@@ -141,7 +141,7 @@ func TestOlmo3Renderer(t *testing.T) {
				`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
				"<|im_start|>environment\n" +
				`{"temperature": 68}<|im_end|>` + "\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "multi-turn conversation",
@@ -159,7 +159,7 @@ func TestOlmo3Renderer(t *testing.T) {
				"Hi there!<|im_end|>\n" +
				"<|im_start|>user\n" +
				"How are you?<|im_end|>\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "parallel tool calls - newline separated",
@@ -214,7 +214,7 @@ func TestOlmo3Renderer(t *testing.T) {
				`{"temperature": 68}<|im_end|>` + "\n" +
				"<|im_start|>environment\n" +
				`{"temperature": 55}<|im_end|>` + "\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "tool call with multiple arguments",
@@ -259,7 +259,7 @@ func TestOlmo3Renderer(t *testing.T) {
				"Book a flight<|im_end|>\n" +
				"<|im_start|>assistant\n" +
				`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
				"<|im_start|>assistant\n\n",
				"<|im_start|>assistant\n",
		},
		{
			name: "assistant prefill - no generation prompt",
@@ -1,31 +1,31 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type Olmo3ThinkVariant int
|
||||
|
||||
const (
|
||||
olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai."
|
||||
olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions."
|
||||
// Olmo3Think32B is for allenai/Olmo-3-32B-Think
|
||||
Olmo3Think32B Olmo3ThinkVariant = iota
|
||||
// Olmo31Think is for allenai/Olmo-3-7B-Think and allenai/Olmo-3.1-32B-Think (includes model info)
|
||||
Olmo31Think
|
||||
)
|
||||
|
||||
type Olmo3ThinkRenderer struct{}
|
||||
const (
|
||||
olmo3ThinkFunctionsSuffix = " You do not currently have access to any functions. <functions></functions>"
|
||||
olmo3Think32BSystemMessage = "You are a helpful AI assistant."
|
||||
olmo31ThinkSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai."
|
||||
)
|
||||
|
||||
type olmo3ThinkToolCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function olmo3ThinkToolCallFunc `json:"function"`
|
||||
type Olmo3ThinkRenderer struct {
|
||||
Variant Olmo3ThinkVariant
|
||||
}
|
||||
|
||||
type olmo3ThinkToolCallFunc struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, _ []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
var systemMessage *api.Message
|
||||
@@ -37,34 +37,31 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Skip tool messages - Think models don't support tools
|
||||
if message.Role == "tool" {
|
||||
continue
|
||||
}
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
}
|
||||
|
||||
systemContent := olmo3ThinkDefaultSystemMessage
|
||||
if systemMessage != nil {
|
||||
systemContent = systemMessage.Content
|
||||
}
|
||||
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(systemContent)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(" <functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
if systemMessage != nil {
|
||||
sb.WriteString(systemMessage.Content)
|
||||
sb.WriteString(olmo3ThinkFunctionsSuffix)
|
||||
} else {
|
||||
sb.WriteString(olmo3ThinkNoFunctionsMessage)
|
||||
sb.WriteString(" <functions></functions>")
|
||||
// Default system message varies by variant
|
||||
switch r.Variant {
|
||||
case Olmo3Think32B:
|
||||
sb.WriteString(olmo3Think32BSystemMessage)
|
||||
default: // Olmo3Think7B, Olmo31Think use same template - diverges from HF but confirmed difference from team
|
||||
sb.WriteString(olmo31ThinkSystemMessage)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
for i, message := range filteredMessages {
|
||||
lastMessage := i == len(filteredMessages)-1
|
||||
|
||||
for _, message := range filteredMessages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
@@ -73,58 +70,15 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls))
|
||||
for j, tc := range message.ToolCalls {
|
||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
toolCalls[j] = olmo3ThinkToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: olmo3ThinkToolCallFunc{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}
|
||||
}
|
||||
toolCallsJSON, err := marshalWithSpaces(toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString("<function_calls>")
|
||||
sb.WriteString(string(toolCallsJSON))
|
||||
sb.WriteString("</function_calls>")
|
||||
}
|
||||
|
||||
if !lastMessage {
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
case "tool":
|
||||
sb.WriteString("<|im_start|>environment\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
}
|
||||
|
||||
needsGenerationPrompt := true
|
||||
if len(filteredMessages) > 0 {
|
||||
lastMsg := filteredMessages[len(filteredMessages)-1]
|
||||
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
|
||||
needsGenerationPrompt = false
|
||||
}
|
||||
}
|
||||
|
||||
if needsGenerationPrompt {
|
||||
sb.WriteString("<|im_start|>assistant\n<think>")
|
||||
}
|
||||
// Always add generation prompt with <think> tag for thinking models
|
||||
sb.WriteString("<|im_start|>assistant\n<think>")
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -11,24 +11,27 @@ import (
|
||||
func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
variant Olmo3ThinkVariant
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic without system - adds default system",
|
||||
name: "7b_basic_without_system",
|
||||
variant: Olmo31Think,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "with system message no tools",
|
||||
name: "7b_with_custom_system",
|
||||
variant: Olmo31Think,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
@@ -41,9 +44,9 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "with system message and tools",
|
||||
name: "7b_tools_ignored",
|
||||
variant: Olmo31Think,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
@@ -52,27 +55,20 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
name: "7b_tool_calls_and_tool_messages_ignored",
|
||||
variant: Olmo31Think,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather in SF?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
@@ -81,53 +77,33 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather in SF?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`Let me check the weather.<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}]</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"Let me check the weather.<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
name: "7b_multi_turn_conversation",
|
||||
variant: Olmo31Think,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
@@ -138,73 +114,56 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls",
|
||||
name: "32b_basic_without_system",
|
||||
variant: Olmo3Think32B,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Get weather in SF and NYC"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. <functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"You are a helpful AI assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Get weather in SF and NYC<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}, {"id": "call_2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}]</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 55}<|im_end|>` + "\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "assistant message only content no tool calls",
|
||||
name: "32b_with_custom_system_gets_suffix",
|
||||
variant: Olmo3Think32B,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Tell me a joke"},
|
||||
{Role: "assistant", Content: "Why did the chicken cross the road?"},
|
||||
{Role: "user", Content: "I don't know, why?"},
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Tell me a joke<|im_end|>\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Why did the chicken cross the road?<|im_end|>\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "31_basic_without_system",
|
||||
variant: Olmo31Think,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"I don't know, why?<|im_end|>\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "31_with_custom_system_gets_suffix",
|
||||
variant: Olmo31Think,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
@@ -212,7 +171,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := (&Olmo3ThinkRenderer{}).Render(tt.msgs, tt.tools, nil)
rendered, err := (&Olmo3ThinkRenderer{Variant: tt.variant}).Render(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}
@@ -60,10 +60,18 @@ func rendererForName(name string) Renderer {
renderer := &CogitoRenderer{isThinking: true}
return renderer
case "olmo3":
renderer := &Olmo3Renderer{}
renderer := &Olmo3Renderer{UseExtendedSystemMessage: false}
return renderer
case "olmo3.1":
renderer := &Olmo3Renderer{UseExtendedSystemMessage: true}
return renderer
case "olmo3-think":
renderer := &Olmo3ThinkRenderer{}
// Used for Olmo-3-7B-Think and Olmo-3.1-32B-Think (same template)
renderer := &Olmo3ThinkRenderer{Variant: Olmo31Think}
return renderer
case "olmo3-32b-think":
// Used for Olmo-3-32B-Think
renderer := &Olmo3ThinkRenderer{Variant: Olmo3Think32B}
return renderer
default:
return nil
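The switch now threads a variant flag into the think renderer. The constants Olmo31Think and Olmo3Think32B appear in the tests and in the cases above, but their type and declaration are not shown in this diff; the following is a purely hypothetical, int-backed sketch, with the package name also assumed:

// Package name assumed for illustration; the real definitions live elsewhere in the PR.
package renderers

type OlmoThinkVariant int

const (
	Olmo31Think OlmoThinkVariant = iota // Olmo-3-7B-Think and Olmo-3.1-32B-Think
	Olmo3Think32B                       // Olmo-3-32B-Think
)

// The renderer would then branch on Variant when building the system prompt.
type Olmo3ThinkRenderer struct {
	Variant OlmoThinkVariant
}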
@@ -487,29 +487,9 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}

types := []string{"jpeg", "jpg", "png", "webp"}
valid := false
// support blank mime type to match api/chat taking just unadorned base64
if strings.HasPrefix(url, "data:;base64,") {
url = strings.TrimPrefix(url, "data:;base64,")
valid = true
}
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}

if !valid {
return nil, errors.New("invalid image input")
}

img, err := base64.StdEncoding.DecodeString(url)
img, err := decodeImageURL(url)
if err != nil {
return nil, errors.New("invalid message format")
return nil, err
}

messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
@@ -648,6 +628,35 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
return ""
}

// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
types := []string{"jpeg", "jpg", "png", "webp"}

// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
if strings.HasPrefix(url, "data:;base64,") {
url = strings.TrimPrefix(url, "data:;base64,")
} else {
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, errors.New("invalid image input")
}
}

img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, errors.New("invalid image input")
}
return img, nil
}
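A small usage sketch of the extracted helper, grounded in the signatures shown in this hunk; the payload is a stand-in rather than a real image, and the snippet assumes the surrounding openai package's existing imports (base64, api):

// Illustrative only: encode a tiny payload and run it through decodeImageURL.
payload := base64.StdEncoding.EncodeToString([]byte("fake image bytes"))
if img, err := decodeImageURL("data:image/png;base64," + payload); err == nil {
	_ = api.Message{Role: "user", Images: []api.ImageData{img}}
}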

// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
apiToolCalls := make([]api.ToolCall, len(toolCalls))
openai/responses.go (new file, 1015 lines): file diff suppressed because it is too large
openai/responses_test.go (new file, 1842 lines): file diff suppressed because it is too large
@@ -26,6 +26,7 @@ import (
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/runner/common"
)
@@ -832,7 +833,7 @@ func (s *Server) loadModel(
ppath string,
kvSize int,
kvCacheType string,
flashAttention bool,
flashAttention ml.FlashAttentionType,
threads int,
multiUserCache bool,
) {
@@ -842,7 +843,7 @@ func (s *Server) loadModel(
panic(err)
}

ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
ctxParams := llama.NewContextParams(kvSize, s.batchSize, s.parallel, threads, flashAttention, kvCacheType)
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
if err != nil {
panic(err)
@@ -1203,16 +1203,22 @@ func (s *Server) allocModel(
return errors.New("loras are not yet implemented")
}

if s.model.Config().Cache == nil {
if parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
if s.batchSize < kvSize {
s.batchSize = kvSize
slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
}
}

s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil {
return err
}

if !s.cache.enabled && parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}

s.parallel = parallel
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
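The new branch checks the model config directly and, for cache-less models, also widens the batch so an entire prompt of up to the context length can be processed in a single forward pass. A toy illustration of that sizing rule, with made-up numbers:

package main

import "fmt"

func main() {
	// Toy values; the real ones come from the request and model configuration.
	kvSize, batchSize := 4096, 512
	if batchSize < kvSize { // no KV cache: the whole context must fit in one batch
		batchSize = kvSize
	}
	fmt.Println(batchSize) // 4096
}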
@@ -1532,6 +1532,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)

if rc != nil {
// wrap old with new
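With the route registered, the endpoint is served by the local Ollama server. A hedged client sketch, assuming the middleware accepts an OpenAI Responses-style JSON body with model and input fields; the model name and prompt are placeholders:

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body := []byte(`{"model": "llama3.2", "input": "Say hello"}`)
	resp, err := http.Post("http://localhost:11434/v1/responses", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}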
@@ -2393,3 +2394,4 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}
@@ -127,6 +127,9 @@ var funcs = template.FuncMap{
// Default format is YYYY-MM-DD
return time.Now().Format("2006-01-02")
},
"yesterdayDate": func(args ...string) string {
return time.Now().AddDate(0, 0, -1).Format("2006-01-02")
},
"toTypeScriptType": func(v any) string {
if param, ok := v.(api.ToolProperty); ok {
return param.ToTypeScriptType()
@@ -10,6 +10,7 @@ import (
"slices"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
@@ -451,6 +452,72 @@ func TestExecuteWithSuffix(t *testing.T) {
}
}

func TestDateFunctions(t *testing.T) {
t.Run("currentDate", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Today is {{ currentDate }}")
if err != nil {
t.Fatal(err)
}

var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}

expected := "Hello Today is " + time.Now().Format("2006-01-02")
if b.String() != expected {
t.Errorf("got %q, want %q", b.String(), expected)
}
})

t.Run("yesterdayDate", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Yesterday was {{ yesterdayDate }}")
if err != nil {
t.Fatal(err)
}

var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}

expected := "Hello Yesterday was " + time.Now().AddDate(0, 0, -1).Format("2006-01-02")
if b.String() != expected {
t.Errorf("got %q, want %q", b.String(), expected)
}
})

t.Run("yesterdayDate format", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ end }}{{ yesterdayDate }}")
if err != nil {
t.Fatal(err)
}

var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}

// Verify the format matches YYYY-MM-DD
result := b.String()
if len(result) != 10 {
t.Errorf("expected date length 10, got %d: %q", len(result), result)
}

// Parse and verify it's a valid date
parsed, err := time.Parse("2006-01-02", result)
if err != nil {
t.Errorf("failed to parse date %q: %v", result, err)
}

// Verify it's yesterday
yesterday := time.Now().AddDate(0, 0, -1)
if parsed.Year() != yesterday.Year() || parsed.Month() != yesterday.Month() || parsed.Day() != yesterday.Day() {
t.Errorf("expected yesterday's date, got %v", parsed)
}
})
}

func TestCollate(t *testing.T) {
cases := []struct {
name string