mirror of
https://github.com/ollama/ollama.git
synced 2026-01-17 20:11:14 -05:00
Compare commits
15 Commits
grace/addi
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ebbebdf3b1 | ||
|
|
49393385ca | ||
|
|
12ff2d1461 | ||
|
|
f90d968b8b | ||
|
|
c623b256a3 | ||
|
|
8c8fb2f9f0 | ||
|
|
6e00a0c89a | ||
|
|
55b1ee2557 | ||
|
|
51cb1155ba | ||
|
|
7c5b656bb3 | ||
|
|
bddb27ab5b | ||
|
|
172b5924af | ||
|
|
8852220f59 | ||
|
|
7325791599 | ||
|
|
522c11a763 |
@@ -341,7 +341,7 @@ type ToolFunctionParameters struct {
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
@@ -352,7 +352,7 @@ func (t *ToolFunctionParameters) String() string {
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters ToolFunctionParameters `json:"parameters,omitempty"`
|
||||
Parameters ToolFunctionParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
func (t *ToolFunction) String() string {
|
||||
@@ -554,6 +554,9 @@ type CreateRequest struct {
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
// Requires is the minimum version of Ollama required by the model.
|
||||
Requires string `json:"requires,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
@@ -604,6 +607,7 @@ type ShowResponse struct {
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
Requires string `json:"requires,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
|
||||
@@ -943,6 +943,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
|
||||
|
||||
@@ -291,6 +291,31 @@ Weigh anchor!
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("min version", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := ` Model
|
||||
architecture test
|
||||
parameters 7B
|
||||
quantization FP16
|
||||
requires 0.14.0
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteHandler(t *testing.T) {
|
||||
|
||||
@@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
|
||||
// temporary fix to handle gemma3 broken configs
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
|
||||
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
|
||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||
}
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ INSTRUCTION arguments
|
||||
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
|
||||
| [`LICENSE`](#license) | Specifies the legal license. |
|
||||
| [`MESSAGE`](#message) | Specify message history. |
|
||||
| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -248,6 +249,16 @@ MESSAGE user Is Ontario in Canada?
|
||||
MESSAGE assistant yes
|
||||
```
|
||||
|
||||
### REQUIRES
|
||||
|
||||
The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
|
||||
|
||||
```
|
||||
REQUIRES <version>
|
||||
```
|
||||
|
||||
The version should be a valid Ollama version (e.g. 0.14.0).
|
||||
|
||||
## Notes
|
||||
|
||||
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
|
||||
|
||||
15
go.mod
15
go.mod
@@ -15,8 +15,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.36.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -29,7 +29,8 @@ require (
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
@@ -76,11 +77,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/term v0.30.0
|
||||
golang.org/x/text v0.23.0
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
30
go.sum
30
go.sum
@@ -224,8 +224,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -255,6 +255,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -267,8 +269,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -278,8 +280,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -295,17 +297,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -319,8 +321,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -524,8 +524,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
||||
// Use the size of one layer as a buffer
|
||||
layers := s.ggml.Tensors().GroupLayers()
|
||||
if blk0, ok := layers["blk.0"]; ok {
|
||||
buffer := blk0.Size() + kv[0]
|
||||
for i := range gpus {
|
||||
gpus[i].FreeMemory -= blk0.Size() + kv[0]
|
||||
if gpus[i].FreeMemory > buffer {
|
||||
gpus[i].FreeMemory -= buffer
|
||||
} else {
|
||||
gpus[i].FreeMemory = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slog.Warn("model missing blk.0 layer size")
|
||||
@@ -575,7 +580,11 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
||||
projectorGPU = firstIntegrated
|
||||
}
|
||||
|
||||
gpus[projectorGPU].FreeMemory -= projectorWeights
|
||||
if gpus[projectorGPU].FreeMemory > projectorWeights {
|
||||
gpus[projectorGPU].FreeMemory -= projectorWeights
|
||||
} else {
|
||||
gpus[projectorGPU].FreeMemory = 0
|
||||
}
|
||||
}
|
||||
|
||||
var kvTotal uint64
|
||||
|
||||
323
model/parsers/functiongemma.go
Normal file
323
model/parsers/functiongemma.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package parsers
|
||||
|
||||
import (
	"fmt"
	"regexp"
	"strconv"
	"strings"

	"github.com/ollama/ollama/api"
)
|
||||
|
||||
// FunctionGemmaParserState enumerates the two modes of the streaming parser:
// accumulating plain content, or accumulating the body of a
// <start_function_call>...<end_function_call> region.
type FunctionGemmaParserState int

const (
	FunctionGemmaCollectingContent FunctionGemmaParserState = iota
	FunctionGemmaCollectingToolCalls
)

// Tags that delimit a single tool call in the model output stream.
const (
	functionGemmaFunctionCallOpen  = "<start_function_call>"
	functionGemmaFunctionCallClose = "<end_function_call>"
)

// This format uses <start_function_call>call:name{args}<end_function_call> for tool calls.
type FunctionGemmaParser struct {
	state  FunctionGemmaParserState // current parsing mode
	buffer strings.Builder          // unconsumed model output, possibly ending mid-tag
	tools  []api.Tool               // tools advertised for the current request
}

// HasToolSupport reports that this parser understands tool calls.
func (p *FunctionGemmaParser) HasToolSupport() bool { return true }

// HasThinkingSupport reports that this format carries no thinking section.
func (p *FunctionGemmaParser) HasThinkingSupport() bool { return false }

// Init prepares the parser for a new request. It returns the tool list
// unchanged — this format needs no tool transformation.
func (p *FunctionGemmaParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	p.state = FunctionGemmaCollectingContent
	return tools
}
|
||||
|
||||
// functionGemmaEvent is the union of things the parser can emit while
// consuming model output: a chunk of plain content or a completed tool call.
type functionGemmaEvent interface {
	isFunctionGemmaEvent()
}

// FunctionGemmaEventContent carries a chunk of plain (non-tool-call) text.
type FunctionGemmaEventContent struct {
	content string
}

// functionGemmaEventToolCall carries one fully parsed tool call.
type functionGemmaEventToolCall struct {
	toolCall api.ToolCall
}

func (FunctionGemmaEventContent) isFunctionGemmaEvent()  {}
func (functionGemmaEventToolCall) isFunctionGemmaEvent() {}
|
||||
|
||||
func (p *FunctionGemmaParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case functionGemmaEventToolCall:
|
||||
toolCalls = append(toolCalls, event.toolCall)
|
||||
case FunctionGemmaEventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), "", toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *FunctionGemmaParser) parseEvents() []functionGemmaEvent {
|
||||
var all []functionGemmaEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []functionGemmaEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// emitWithPartialCheck extracts unambiguous content before a potential partial tag.
// If the tail of bufStr could be the beginning of tag (as reported by
// overlap, defined elsewhere in this package — presumably the length of the
// longest suffix of bufStr that is a prefix of tag), that tail is held back
// as ambiguous so it is not emitted as content before more input arrives.
func (p *FunctionGemmaParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
	if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
		beforePartialTag := bufStr[:len(bufStr)-overlapLen]
		return beforePartialTag, bufStr[len(beforePartialTag):]
	}
	return bufStr, ""
}
|
||||
|
||||
// eat consumes as much of the buffer as can currently be interpreted and
// returns the events produced, plus whether the caller should call eat again
// (true when the buffer may still contain another complete region).
func (p *FunctionGemmaParser) eat() ([]functionGemmaEvent, bool) {
	bufStr := p.buffer.String()
	if bufStr == "" {
		return nil, false
	}

	switch p.state {
	case FunctionGemmaCollectingContent:
		// A complete open tag switches us into tool-call mode; everything
		// before the tag is plain content.
		if strings.Contains(bufStr, functionGemmaFunctionCallOpen) {
			split := strings.SplitN(bufStr, functionGemmaFunctionCallOpen, 2)
			content := split[0]
			p.buffer.Reset()
			p.buffer.WriteString(split[1])
			p.state = FunctionGemmaCollectingToolCalls
			if content != "" {
				return []functionGemmaEvent{FunctionGemmaEventContent{content: content}}, true
			}
			return nil, true
		}
		// No complete open tag: emit everything except a suffix that might be
		// the start of one; that suffix stays buffered for the next chunk.
		unambig, ambig := p.emitWithPartialCheck(bufStr, functionGemmaFunctionCallOpen)
		p.buffer.Reset()
		p.buffer.WriteString(ambig)
		if unambig != "" {
			return []functionGemmaEvent{FunctionGemmaEventContent{content: unambig}}, false
		}
		return nil, false

	case FunctionGemmaCollectingToolCalls:
		// Wait until the close tag has arrived, then parse the enclosed call.
		if strings.Contains(bufStr, functionGemmaFunctionCallClose) {
			split := strings.SplitN(bufStr, functionGemmaFunctionCallClose, 2)
			remaining := split[1]
			p.buffer.Reset()
			p.buffer.WriteString(remaining)

			var events []functionGemmaEvent
			if tc, err := p.parseToolCall(split[0]); err == nil {
				events = append(events, functionGemmaEventToolCall{toolCall: tc})
			}

			// Stay in tool-call mode only if another open tag is already
			// buffered; otherwise fall back to collecting content.
			if !strings.Contains(remaining, functionGemmaFunctionCallOpen) {
				p.state = FunctionGemmaCollectingContent
			}
			return events, true
		}
		return nil, false
	}

	return nil, false
}
|
||||
|
||||
// Matches call:function_name{args}
|
||||
var functionGemmaCallRegex = regexp.MustCompile(`call:([^{]+)\{(.*)\}`)
|
||||
|
||||
func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error) {
|
||||
toolCall := api.ToolCall{}
|
||||
|
||||
// Extract function name and arguments
|
||||
match := functionGemmaCallRegex.FindStringSubmatch(content)
|
||||
if len(match) < 3 {
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
toolCall.Function.Name = match[1]
|
||||
argsStr := match[2]
|
||||
|
||||
// Parse arguments
|
||||
toolCall.Function.Arguments = p.parseArguments(argsStr)
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
// parseArguments parses the key:value,key:value format
|
||||
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
|
||||
args := make(api.ToolCallFunctionArguments)
|
||||
if argsStr == "" {
|
||||
return args
|
||||
}
|
||||
|
||||
// Split by comma, but handle nested structures
|
||||
parts := p.splitArguments(argsStr)
|
||||
|
||||
for _, part := range parts {
|
||||
// Find the first colon to split key:value
|
||||
colonIdx := strings.Index(part, ":")
|
||||
if colonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := part[:colonIdx]
|
||||
value := part[colonIdx+1:]
|
||||
|
||||
// Parse the value
|
||||
args[key] = p.parseValue(value)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// splitArguments splits arguments by comma, respecting nested structures.
// Commas inside {...} or [...] nesting, or between a pair of <escape> tags,
// do not split; the <escape> tags themselves are kept in the output so that
// parseValue can recognize escaped strings later.
func (p *FunctionGemmaParser) splitArguments(argsStr string) []string {
	var parts []string
	var current strings.Builder
	depth := 0          // current {}/[] nesting level
	inEscape := false   // true while between an odd number of <escape> tags

	for i := 0; i < len(argsStr); i++ {
		ch := argsStr[i]

		// Check for <escape> tags; each occurrence toggles escape mode.
		if i+8 <= len(argsStr) && argsStr[i:i+8] == "<escape>" {
			inEscape = !inEscape
			current.WriteString("<escape>")
			i += 7 // Skip the rest of <escape>
			continue
		}

		if !inEscape {
			switch ch {
			case '{', '[':
				depth++
				current.WriteByte(ch)
			case '}', ']':
				depth--
				current.WriteByte(ch)
			case ',':
				// Only a top-level comma terminates the current part; empty
				// parts (e.g. from consecutive commas) are dropped.
				if depth == 0 {
					if current.Len() > 0 {
						parts = append(parts, current.String())
						current.Reset()
					}
					continue
				}
				current.WriteByte(ch)
			default:
				current.WriteByte(ch)
			}
		} else {
			// Inside <escape>...<escape>: copy bytes verbatim.
			current.WriteByte(ch)
		}
	}

	// Flush the trailing part, if any.
	if current.Len() > 0 {
		parts = append(parts, current.String())
	}

	return parts
}
|
||||
|
||||
// parseValue parses a single value from the FunctionGemma format
|
||||
func (p *FunctionGemmaParser) parseValue(value string) any {
|
||||
// Check for escaped string
|
||||
if strings.HasPrefix(value, "<escape>") && strings.HasSuffix(value, "<escape>") {
|
||||
// Remove the escape tags
|
||||
return value[8 : len(value)-8]
|
||||
}
|
||||
|
||||
// Check for boolean
|
||||
if value == "true" {
|
||||
return true
|
||||
}
|
||||
if value == "false" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for number
|
||||
if num, ok := parseNumber(value); ok {
|
||||
return num
|
||||
}
|
||||
|
||||
// Check for array
|
||||
if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
|
||||
return p.parseArray(value[1 : len(value)-1])
|
||||
}
|
||||
|
||||
// Check for object
|
||||
if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
|
||||
return p.parseObject(value[1 : len(value)-1])
|
||||
}
|
||||
|
||||
// Default to string
|
||||
return value
|
||||
}
|
||||
|
||||
// parseArray parses an array value
|
||||
func (p *FunctionGemmaParser) parseArray(content string) []any {
|
||||
var result []any
|
||||
parts := p.splitArguments(content)
|
||||
for _, part := range parts {
|
||||
result = append(result, p.parseValue(part))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseObject parses an object value
|
||||
func (p *FunctionGemmaParser) parseObject(content string) map[string]any {
|
||||
result := make(map[string]any)
|
||||
parts := p.splitArguments(content)
|
||||
for _, part := range parts {
|
||||
colonIdx := strings.Index(part, ":")
|
||||
if colonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
key := part[:colonIdx]
|
||||
value := part[colonIdx+1:]
|
||||
result[key] = p.parseValue(value)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseNumber tries to parse a string as a number, returning (int64, true)
// for integers, (float64, true) for floating-point values, and (nil, false)
// otherwise.
//
// It uses strconv rather than fmt.Sscanf: Sscanf("%f") succeeds on any
// string with a numeric prefix, so inputs like "3abc" were previously
// misparsed as the number 3 instead of falling through to the string case.
// strconv.ParseInt/ParseFloat require the entire string to be numeric.
func parseNumber(s string) (any, bool) {
	// Try integer first so whole numbers keep an integral type.
	if intVal, err := strconv.ParseInt(s, 10, 64); err == nil {
		return intVal, true
	}

	// Fall back to float (also covers exponent notation like 1e5).
	if floatVal, err := strconv.ParseFloat(s, 64); err == nil {
		return floatVal, true
	}

	return nil, false
}
|
||||
423
model/parsers/functiongemma_test.go
Normal file
423
model/parsers/functiongemma_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFunctionGemmaParser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
tools []api.Tool
|
||||
expectedCalls []api.ToolCall
|
||||
expectedText string
|
||||
}{
|
||||
{
|
||||
name: "plain_content",
|
||||
chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
|
||||
expectedCalls: nil,
|
||||
expectedText: "Hello, world!",
|
||||
},
|
||||
{
|
||||
name: "simple_tool_call",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "get", "_", "weather", "{",
|
||||
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "content_before_tool_call",
|
||||
chunks: []string{
|
||||
"L", "et", " ", "me", " ", "check", ".",
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "get", "_", "weather", "{",
|
||||
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "Let me check.",
|
||||
},
|
||||
{
|
||||
name: "numeric_arguments",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "add", "{",
|
||||
"a", ":", "1", ",", "b", ":", "2",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "boolean_arguments",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "set", "_", "flag", "{",
|
||||
"enabled", ":", "true", ",", "verbose", ":", "false",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_flag",
|
||||
Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "multiple_tool_calls",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "get", "_", "weather", "{",
|
||||
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "get", "_", "weather", "{",
|
||||
"city", ":", "<", "escape", ">", "London", "<", "escape", ">",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "London"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "array_argument",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "process", "{",
|
||||
"items", ":", "[",
|
||||
"<", "escape", ">", "a", "<", "escape", ">", ",",
|
||||
"<", "escape", ">", "b", "<", "escape", ">", ",",
|
||||
"<", "escape", ">", "c", "<", "escape", ">",
|
||||
"]",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "object_argument",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "update", "{",
|
||||
"data", ":", "{",
|
||||
"name", ":", "<", "escape", ">", "test", "<", "escape", ">", ",",
|
||||
"value", ":", "42",
|
||||
"}",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"data": map[string]any{"name": "test", "value": int64(42)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "empty_input",
|
||||
chunks: []string{},
|
||||
expectedCalls: nil,
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "tool_call_with_no_arguments",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "get", "_", "time", "{", "}",
|
||||
"<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "content_with_angle_brackets",
|
||||
chunks: []string{
|
||||
"The", " ", "result", " ", "is", " ", "a", " ", "<", "value", ">", " ", "tag",
|
||||
},
|
||||
expectedCalls: nil,
|
||||
expectedText: "The result is a <value> tag",
|
||||
},
|
||||
{
|
||||
name: "float_argument",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "set", "_", "temp", "{",
|
||||
"value", ":", "3", ".", "14",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: api.ToolCallFunctionArguments{"value": 3.14},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "content_after_tool_call",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "test", "{", "}",
|
||||
"<", "end", "_", "function", "_", "call", ">",
|
||||
"Done", "!",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "Done!",
|
||||
},
|
||||
{
|
||||
name: "unicode_content_and_arguments",
|
||||
chunks: []string{
|
||||
"こんにちは", " ",
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "greet", "{",
|
||||
"name", ":", "<", "escape", ">", "日本語", "<", "escape", ">",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "greet",
|
||||
Arguments: api.ToolCallFunctionArguments{"name": "日本語"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "こんにちは ",
|
||||
},
|
||||
{
|
||||
name: "multiple_params_sorted",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "search", "{",
|
||||
"query", ":", "<", "escape", ">", "test", "<", "escape", ">", ",",
|
||||
"limit", ":", "10", ",",
|
||||
"offset", ":", "0",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"query": "test",
|
||||
"limit": int64(10),
|
||||
"offset": int64(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "nested_object_argument",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "create", "{",
|
||||
"config", ":", "{",
|
||||
"settings", ":", "{",
|
||||
"enabled", ":", "true", ",",
|
||||
"name", ":", "<", "escape", ">", "test", "<", "escape", ">",
|
||||
"}",
|
||||
"}",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"config": map[string]any{
|
||||
"settings": map[string]any{
|
||||
"enabled": true,
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "partial_start_tag_in_content",
|
||||
chunks: []string{
|
||||
"Hello", " ", "<", "start", " ", "world",
|
||||
},
|
||||
expectedCalls: nil,
|
||||
expectedText: "Hello <start world",
|
||||
},
|
||||
{
|
||||
name: "parallel_tool_calls",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "get", "_", "weather", "{",
|
||||
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "get", "_", "time", "{",
|
||||
"timezone", ":", "<", "escape", ">", "UTC", "<", "escape", ">",
|
||||
"}", "<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "",
|
||||
},
|
||||
{
|
||||
name: "content_between_tool_calls",
|
||||
chunks: []string{
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "first", "{", "}",
|
||||
"<", "end", "_", "function", "_", "call", ">",
|
||||
"Some", " ", "text", " ", "here",
|
||||
"<", "start", "_", "function", "_", "call", ">",
|
||||
"call", ":", "second", "{", "}",
|
||||
"<", "end", "_", "function", "_", "call", ">",
|
||||
},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "first",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "second",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedText: "Some text here",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &FunctionGemmaParser{}
|
||||
parser.Init(tt.tools, nil, nil)
|
||||
|
||||
var allContent string
|
||||
var allCalls []api.ToolCall
|
||||
|
||||
for i, chunk := range tt.chunks {
|
||||
done := i == len(tt.chunks)-1
|
||||
content, _, calls, err := parser.Add(chunk, done)
|
||||
assert.NoError(t, err)
|
||||
allContent += content
|
||||
allCalls = append(allCalls, calls...)
|
||||
}
|
||||
|
||||
// Handle empty chunks case
|
||||
if len(tt.chunks) == 0 {
|
||||
content, _, calls, err := parser.Add("", true)
|
||||
assert.NoError(t, err)
|
||||
allContent = content
|
||||
allCalls = calls
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedText, allContent)
|
||||
assert.Equal(t, tt.expectedCalls, allCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionGemmaParser_HasSupport(t *testing.T) {
|
||||
parser := &FunctionGemmaParser{}
|
||||
assert.True(t, parser.HasToolSupport())
|
||||
assert.False(t, parser.HasThinkingSupport())
|
||||
}
|
||||
@@ -66,6 +66,8 @@ func ParserForName(name string) Parser {
|
||||
return &Olmo3ThinkParser{}
|
||||
case "nemotron-3-nano":
|
||||
return &Nemotron3NanoParser{}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaParser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
287
model/renderers/functiongemma.go
Normal file
287
model/renderers/functiongemma.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// FunctionGemmaRenderer renders a conversation into the FunctionGemma prompt
// format: turns are delimited with <start_of_turn>/<end_of_turn>, and tool
// declarations, calls, and responses are encoded with dedicated
// <start_function_*>/<end_function_*> markers.
type FunctionGemmaRenderer struct{}

// defaultSystemMessage is written into the developer turn when tools are
// present, unless the user-supplied system message already equals it.
const defaultSystemMessage = "You can do function calling with the following functions:"
|
||||
|
||||
// Render builds the FunctionGemma prompt string for the given messages and
// tools.
//
// The prompt begins with <bos>. A leading "system" or "developer" message is
// merged with the tool declarations into a single developer turn; all other
// messages are rendered as user/model turns. Tool calls and tool responses
// are emitted inline inside the model turn using
// <start_function_call>/<end_function_call> and
// <start_function_response>/<end_function_response> markers.
//
// thinkValue is not used by this renderer. The returned error is always nil.
func (r *FunctionGemmaRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("<bos>")

	// Split off a leading system/developer message; the remaining messages
	// are handled by the main loop below.
	var systemMessage string
	var loopMessages []api.Message
	if len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer") {
		systemMessage = messages[0].Content
		loopMessages = messages[1:]
	} else {
		loopMessages = messages
	}

	// Emit the developer turn when there is a system message and/or tools.
	if systemMessage != "" || len(tools) > 0 {
		sb.WriteString("<start_of_turn>developer\n")
		if systemMessage != "" {
			sb.WriteString(strings.TrimSpace(systemMessage))
		}
		if len(tools) > 0 {
			if systemMessage != "" {
				sb.WriteString("\n")
			}
			if strings.TrimSpace(systemMessage) != defaultSystemMessage {
				// Only add default message if user does not provide it
				sb.WriteString(defaultSystemMessage)
			}
		}
		for _, tool := range tools {
			sb.WriteString(r.renderToolDeclaration(tool))
		}
		sb.WriteString("<end_of_turn>\n")
	}

	// Track previous message type for tool response handling:
	// "tool_call" means a <start_function_response> tag has already been
	// opened for the upcoming tool message; "tool_response" means the model
	// turn is still open, so the next message continues it without a new
	// <start_of_turn>.
	prevMessageType := ""

	for i, message := range loopMessages {
		switch message.Role {
		case "assistant":
			// A preceding tool response leaves the model turn open; only
			// start a new turn otherwise.
			if prevMessageType != "tool_response" {
				sb.WriteString("<start_of_turn>model\n")
			}
			prevMessageType = ""

			if message.Content != "" {
				sb.WriteString(strings.TrimSpace(message.Content))
			}

			if len(message.ToolCalls) > 0 {
				for _, tc := range message.ToolCalls {
					sb.WriteString(r.formatToolCall(tc))
				}
				// After tool calls, expect tool responses
				if i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" {
					sb.WriteString("<start_function_response>")
					prevMessageType = "tool_call"
				} else {
					sb.WriteString("<end_of_turn>\n")
				}
			} else {
				sb.WriteString("<end_of_turn>\n")
			}

		case "user":
			if prevMessageType != "tool_response" {
				sb.WriteString("<start_of_turn>user\n")
			}
			prevMessageType = ""
			sb.WriteString(strings.TrimSpace(message.Content))
			sb.WriteString("<end_of_turn>\n")

		case "tool":
			toolName := ""
			// Find the tool name from the previous assistant's tool call
			for j := i - 1; j >= 0; j-- {
				if loopMessages[j].Role == "assistant" && len(loopMessages[j].ToolCalls) > 0 {
					// Count how many tool messages came before this one, so
					// parallel responses pair with their calls in order.
					toolIdx := 0
					for k := j + 1; k < i; k++ {
						if loopMessages[k].Role == "tool" {
							toolIdx++
						}
					}
					if toolIdx < len(loopMessages[j].ToolCalls) {
						toolName = loopMessages[j].ToolCalls[toolIdx].Function.Name
					}
					break
				}
			}

			// The assistant case has already opened the response tag when it
			// saw a following tool message; don't write it twice.
			if prevMessageType != "tool_call" {
				sb.WriteString("<start_function_response>")
			}
			sb.WriteString("response:" + toolName + "{" + r.formatArgValue(message.Content) + "}<end_function_response>")
			prevMessageType = "tool_response"

		default:
			// Unknown roles are passed through as their own turn.
			sb.WriteString("<start_of_turn>" + message.Role + "\n")
			sb.WriteString(strings.TrimSpace(message.Content))
			sb.WriteString("<end_of_turn>\n")
		}
	}

	// Prompt the model to respond, unless a tool response just left the
	// model turn open (the model continues generating in that turn).
	if prevMessageType != "tool_response" {
		sb.WriteString("<start_of_turn>model\n")
	}

	return sb.String(), nil
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
||||
var sb strings.Builder
|
||||
|
||||
fn := tool.Function
|
||||
sb.WriteString("<start_function_declaration>declaration:" + fn.Name + "{")
|
||||
sb.WriteString("description:<escape>" + fn.Description + "<escape>")
|
||||
|
||||
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
|
||||
sb.WriteString(",parameters:{")
|
||||
|
||||
needsComma := false
|
||||
|
||||
// Only include properties:{} if there are actual properties
|
||||
if len(fn.Parameters.Properties) > 0 {
|
||||
sb.WriteString("properties:{")
|
||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||
sb.WriteString("}")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if len(fn.Parameters.Required) > 0 {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("required:[")
|
||||
for i, req := range fn.Parameters.Required {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("<escape>" + req + "<escape>")
|
||||
}
|
||||
sb.WriteString("]")
|
||||
needsComma = true
|
||||
}
|
||||
|
||||
if fn.Parameters.Type != "" {
|
||||
if needsComma {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString("type:<escape>" + strings.ToUpper(fn.Parameters.Type) + "<escape>")
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
|
||||
sb.WriteString("}<end_function_declaration>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
|
||||
keys := make([]string, 0, len(props))
|
||||
for k := range props {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, name := range keys {
|
||||
prop := props[name]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
|
||||
sb.WriteString(name + ":{description:<escape>")
|
||||
sb.WriteString(prop.Description)
|
||||
sb.WriteString("<escape>")
|
||||
|
||||
if len(prop.Type) > 0 {
|
||||
sb.WriteString(",type:<escape>" + strings.ToUpper(prop.Type[0]) + "<escape>")
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
|
||||
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value := tc.Function.Arguments[key]
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(value))
|
||||
}
|
||||
|
||||
sb.WriteString("}<end_function_call>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatArgValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return "<escape>" + v + "<escape>"
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case float64:
|
||||
if v == float64(int64(v)) {
|
||||
return fmt.Sprintf("%d", int64(v))
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case map[string]any:
|
||||
return r.formatMapValue(v)
|
||||
case []any:
|
||||
return r.formatArrayValue(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatMapValue(m map[string]any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("{")
|
||||
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) formatArrayValue(arr []any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[")
|
||||
|
||||
for i, item := range arr {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(r.formatArgValue(item))
|
||||
}
|
||||
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
514
model/renderers/functiongemma_test.go
Normal file
514
model/renderers/functiongemma_test.go
Normal file
@@ -0,0 +1,514 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFunctionGemmaRenderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic_user_message",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "with_system_message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou are helpful<end_of_turn>\n<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "with_developer_role",
|
||||
messages: []api.Message{
|
||||
{Role: "developer", Content: "You are a coding assistant"},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou are a coding assistant<end_of_turn>\n<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "custom_system_message_with_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a weather expert."},
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Custom system message is preserved, tools are appended
|
||||
expected: "<bos><start_of_turn>developer\nYou are a weather expert.\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "developer_role_with_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "developer", Content: "Be concise."},
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Developer role message is preserved, tools are appended
|
||||
expected: "<bos><start_of_turn>developer\nBe concise.\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "multi_turn",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello!"},
|
||||
{Role: "user", Content: "More"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\nHello!<end_of_turn>\n<start_of_turn>user\nMore<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "with_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "tool_call",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "assistant_content_with_tool_call",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\nLet me check.<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "numeric_arguments",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Add"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "3"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "add",
|
||||
Description: "Add numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"number"}},
|
||||
"b": {Type: api.PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:add{description:<escape>Add numbers<escape>,parameters:{properties:{a:{description:<escape><escape>,type:<escape>NUMBER<escape>},b:{description:<escape><escape>,type:<escape>NUMBER<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nAdd<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:add{a:1,b:2}<end_function_call><start_function_response>response:add{<escape>3<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "empty_messages",
|
||||
messages: []api.Message{},
|
||||
expected: "<bos><start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "tool_with_required_params",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Gets the weather for a given city",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
|
||||
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Required params are escaped: required:[<escape>city<escape>]
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Gets the weather for a given city<escape>,parameters:{properties:{city:{description:<escape>City Name<escape>,type:<escape>STRING<escape>},country:{description:<escape>Country Name<escape>,type:<escape>STRING<escape>}},required:[<escape>city<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "multiple_tools",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather and time?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_time",
|
||||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Multiple tool declarations are consecutive
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:get_time{description:<escape>Get current time<escape>,parameters:{properties:{timezone:{description:<escape>Timezone<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather and time?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "parallel_tool_calls",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather and time?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
{Role: "tool", Content: "12:00"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_time",
|
||||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Multiple tool calls and responses are consecutive
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:get_time{description:<escape>Get current time<escape>,parameters:{properties:{timezone:{description:<escape>Timezone<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather and time?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_call>call:get_time{timezone:<escape>UTC<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response><start_function_response>response:get_time{<escape>12:00<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "user_after_tool_response",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Sunny"},
|
||||
{Role: "user", Content: "Thanks! What about London?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// User message after tool response gets concatenated (user reverted to this behavior)
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>Thanks! What about London?<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
// Edge cases
|
||||
{
|
||||
name: "tool_empty_properties",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "test_fn",
|
||||
Description: "",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Empty properties are omitted
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test_fn{description:<escape><escape>,parameters:{type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "unicode_content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "こんにちは 🎉"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nこんにちは 🎉<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "newlines_in_content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Line 1\nLine 2\nLine 3"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nLine 1\nLine 2\nLine 3<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "special_chars_in_content",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test <tag> & \"quotes\" chars"},
|
||||
},
|
||||
expected: "<bos><start_of_turn>user\nTest <tag> & \"quotes\" chars<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "boolean_argument",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Set flag"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_flag",
|
||||
Arguments: api.ToolCallFunctionArguments{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "done"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "set_flag",
|
||||
Description: "Set a flag",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:set_flag{description:<escape>Set a flag<escape>,parameters:{properties:{enabled:{description:<escape>Flag value<escape>,type:<escape>BOOLEAN<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nSet flag<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:set_flag{enabled:true}<end_function_call><start_function_response>response:set_flag{<escape>done<escape>}<end_function_response>",
|
||||
},
|
||||
{
|
||||
name: "multiple_required_params",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "test",
|
||||
Description: "Test",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"a", "b", "c"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"string"}, Description: "A"},
|
||||
"b": {Type: api.PropertyType{"string"}, Description: "B"},
|
||||
"c": {Type: api.PropertyType{"string"}, Description: "C"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test{description:<escape>Test<escape>,parameters:{properties:{a:{description:<escape>A<escape>,type:<escape>STRING<escape>},b:{description:<escape>B<escape>,type:<escape>STRING<escape>},c:{description:<escape>C<escape>,type:<escape>STRING<escape>}},required:[<escape>a<escape>,<escape>b<escape>,<escape>c<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "array_type_param",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Test"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "test",
|
||||
Description: "Test",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test{description:<escape>Test<escape>,parameters:{properties:{items:{description:<escape>List of items<escape>,type:<escape>ARRAY<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
renderer := &FunctionGemmaRenderer{}
|
||||
result, err := renderer.Render(tt.messages, tt.tools, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -78,6 +78,8 @@ func rendererForName(name string) Renderer {
|
||||
return renderer
|
||||
case "nemotron-3-nano":
|
||||
return &Nemotron3NanoRenderer{}
|
||||
case "functiongemma":
|
||||
return &FunctionGemmaRenderer{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
@@ -104,6 +105,16 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
req.Renderer = c.Args
|
||||
case "parser":
|
||||
req.Parser = c.Args
|
||||
case "requires":
|
||||
// golang.org/x/mod/semver requires "v" prefix
|
||||
requires := c.Args
|
||||
if !strings.HasPrefix(requires, "v") {
|
||||
requires = "v" + requires
|
||||
}
|
||||
if !semver.IsValid(requires) {
|
||||
return nil, fmt.Errorf("requires must be a valid semver (e.g. 0.14.0)")
|
||||
}
|
||||
req.Requires = strings.TrimPrefix(requires, "v")
|
||||
case "message":
|
||||
role, msg, _ := strings.Cut(c.Args, ": ")
|
||||
messages = append(messages, api.Message{Role: role, Content: msg})
|
||||
@@ -322,7 +333,7 @@ func (c Command) String() string {
|
||||
switch c.Name {
|
||||
case "model":
|
||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||
case "license", "template", "system", "adapter", "renderer", "parser":
|
||||
case "license", "template", "system", "adapter", "renderer", "parser", "requires":
|
||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
||||
case "message":
|
||||
role, message, _ := strings.Cut(c.Args, ": ")
|
||||
@@ -348,7 +359,7 @@ const (
|
||||
var (
|
||||
errMissingFrom = errors.New("no FROM line")
|
||||
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
|
||||
)
|
||||
|
||||
type ParserError struct {
|
||||
@@ -608,7 +619,7 @@ func isValidMessageRole(role string) bool {
|
||||
|
||||
func isValidCommand(cmd string) bool {
|
||||
switch strings.ToLower(cmd) {
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -61,6 +61,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
config.Renderer = r.Renderer
|
||||
config.Parser = r.Parser
|
||||
config.Requires = r.Requires
|
||||
|
||||
for v := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
@@ -120,7 +121,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "") {
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
|
||||
manifest, mErr := ParseNamedManifest(fromName)
|
||||
if mErr == nil && manifest.Config.Digest != "" {
|
||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
||||
@@ -134,6 +135,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
if config.Parser == "" {
|
||||
config.Parser = baseConfig.Parser
|
||||
}
|
||||
if config.Requires == "" {
|
||||
config.Requires = baseConfig.Requires
|
||||
}
|
||||
}
|
||||
cfgFile.Close()
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
@@ -31,9 +33,45 @@ const maxRetries = 6
|
||||
var (
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errPartSlow = errors.New("part slow, racing")
|
||||
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
|
||||
)
|
||||
|
||||
// speedTracker tracks download speeds and computes rolling median.
|
||||
type speedTracker struct {
|
||||
mu sync.Mutex
|
||||
speeds []float64 // bytes per second
|
||||
}
|
||||
|
||||
func (s *speedTracker) Record(bytesPerSec float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.speeds = append(s.speeds, bytesPerSec)
|
||||
// Keep last 100 samples
|
||||
if len(s.speeds) > 100 {
|
||||
s.speeds = s.speeds[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *speedTracker) Median() float64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if len(s.speeds) < 3 {
|
||||
return 0 // not enough data
|
||||
}
|
||||
// Simple median: sort a copy and take middle
|
||||
sorted := make([]float64, len(s.speeds))
|
||||
copy(sorted, s.speeds)
|
||||
for i := range sorted {
|
||||
for j := i + 1; j < len(sorted); j++ {
|
||||
if sorted[j] < sorted[i] {
|
||||
sorted[i], sorted[j] = sorted[j], sorted[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return sorted[len(sorted)/2]
|
||||
}
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
|
||||
type blobDownload struct {
|
||||
@@ -94,26 +132,127 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
numDownloadParts = 16
|
||||
minDownloadPartSize int64 = 100 * format.MegaByte
|
||||
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
||||
var (
|
||||
downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
|
||||
downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48)
|
||||
)
|
||||
|
||||
func envInt(key string, defaultVal int) int {
|
||||
if s := os.Getenv(key); s != "" {
|
||||
if v, err := strconv.Atoi(s); err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// streamHasher reads a file sequentially and hashes it as chunks complete.
|
||||
// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
|
||||
// Works by reading from OS page cache - data just written is still in RAM.
|
||||
type streamHasher struct {
|
||||
file *os.File
|
||||
hasher hash.Hash
|
||||
parts []*blobDownloadPart
|
||||
total int64 // total bytes to hash
|
||||
hashed atomic.Int64
|
||||
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
completed []bool
|
||||
done bool
|
||||
err error
|
||||
}
|
||||
|
||||
func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
|
||||
h := &streamHasher{
|
||||
file: file,
|
||||
hasher: sha256.New(),
|
||||
parts: parts,
|
||||
total: total,
|
||||
completed: make([]bool, len(parts)),
|
||||
}
|
||||
h.cond = sync.NewCond(&h.mu)
|
||||
return h
|
||||
}
|
||||
|
||||
// MarkComplete signals that a part has been written to disk.
|
||||
func (h *streamHasher) MarkComplete(partIndex int) {
|
||||
h.mu.Lock()
|
||||
h.completed[partIndex] = true
|
||||
h.cond.Broadcast()
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// Run reads and hashes the file sequentially. Call in a goroutine.
|
||||
func (h *streamHasher) Run() {
|
||||
buf := make([]byte, 64*1024) // 64KB read buffer
|
||||
var offset int64
|
||||
|
||||
for i, part := range h.parts {
|
||||
// Wait for this part to be written
|
||||
h.mu.Lock()
|
||||
for !h.completed[i] && !h.done {
|
||||
h.cond.Wait()
|
||||
}
|
||||
if h.done {
|
||||
h.mu.Unlock()
|
||||
return
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
// Read and hash this part (from page cache)
|
||||
remaining := part.Size
|
||||
for remaining > 0 {
|
||||
n := int64(len(buf))
|
||||
if n > remaining {
|
||||
n = remaining
|
||||
}
|
||||
nr, err := h.file.ReadAt(buf[:n], offset)
|
||||
if err != nil && err != io.EOF {
|
||||
h.mu.Lock()
|
||||
h.err = err
|
||||
h.mu.Unlock()
|
||||
return
|
||||
}
|
||||
h.hasher.Write(buf[:nr])
|
||||
offset += int64(nr)
|
||||
remaining -= int64(nr)
|
||||
h.hashed.Store(offset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop signals the hasher to exit early.
|
||||
func (h *streamHasher) Stop() {
|
||||
h.mu.Lock()
|
||||
h.done = true
|
||||
h.cond.Broadcast()
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// Hashed returns bytes hashed so far.
|
||||
func (h *streamHasher) Hashed() int64 {
|
||||
return h.hashed.Load()
|
||||
}
|
||||
|
||||
// Digest returns the computed hash.
|
||||
func (h *streamHasher) Digest() string {
|
||||
return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// Err returns any error from hashing.
|
||||
func (h *streamHasher) Err() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.err
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Name() string {
|
||||
return strings.Join([]string{
|
||||
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
||||
}, "-")
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StartsAt() int64 {
|
||||
return p.Offset + p.Completed.Load()
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StopsAt() int64 {
|
||||
return p.Offset + p.Size
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
@@ -151,14 +290,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||
|
||||
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||
|
||||
size := b.Total / numDownloadParts
|
||||
switch {
|
||||
case size < minDownloadPartSize:
|
||||
size = minDownloadPartSize
|
||||
case size > maxDownloadPartSize:
|
||||
size = maxDownloadPartSize
|
||||
}
|
||||
|
||||
size := downloadPartSize
|
||||
var offset int64
|
||||
for offset < b.Total {
|
||||
if offset+size > b.Total {
|
||||
@@ -220,9 +352,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
setSparse(file)
|
||||
|
||||
_ = file.Truncate(b.Total)
|
||||
|
||||
directURL, err := func() (*url.URL, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
@@ -270,44 +399,106 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return err
|
||||
}
|
||||
|
||||
// Download chunks to disk, hash by reading from page cache.
|
||||
// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
|
||||
// The hasher follows behind the downloaders, reading recently-written
|
||||
// data from OS page cache (RAM) rather than disk.
|
||||
sh := newStreamHasher(file, b.Parts, b.Total)
|
||||
tracker := &speedTracker{}
|
||||
|
||||
// Start hasher goroutine
|
||||
hashDone := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(hashDone)
|
||||
}()
|
||||
|
||||
// Log progress periodically
|
||||
// Page cache warning: if spread > 4GB, hasher may hit disk instead of RAM
|
||||
const pageCacheWarningBytes = 4 << 30 // 4GB
|
||||
progressDone := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
downloaded := b.Completed.Load()
|
||||
hashed := sh.Hashed()
|
||||
dlPct := int(downloaded * 100 / b.Total)
|
||||
hPct := int(hashed * 100 / b.Total)
|
||||
spread := dlPct - hPct
|
||||
spreadBytes := downloaded - hashed
|
||||
|
||||
slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread))
|
||||
if spreadBytes > pageCacheWarningBytes {
|
||||
slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30)))
|
||||
}
|
||||
case <-progressDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
g, inner := errgroup.WithContext(ctx)
|
||||
g.SetLimit(numDownloadParts)
|
||||
g.SetLimit(downloadConcurrency)
|
||||
for i := range b.Parts {
|
||||
part := b.Parts[i]
|
||||
if part.Completed.Load() == part.Size {
|
||||
sh.MarkComplete(part.N)
|
||||
continue
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
var slowRetries int
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||
err = b.downloadChunk(inner, directURL, w, part)
|
||||
// After 3 slow retries, stop checking slowness and let it complete
|
||||
skipSlowCheck := slowRetries >= 3
|
||||
err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
return err
|
||||
case errors.Is(err, errPartStalled):
|
||||
try--
|
||||
continue
|
||||
case errors.Is(err, errPartSlow):
|
||||
// Kill slow request, retry immediately (stays within concurrency limit)
|
||||
slowRetries++
|
||||
try--
|
||||
continue
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
time.Sleep(sleep)
|
||||
continue
|
||||
default:
|
||||
sh.MarkComplete(part.N)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
close(progressDone)
|
||||
sh.Stop()
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for hasher to finish
|
||||
<-hashDone
|
||||
close(progressDone)
|
||||
if err := sh.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify hash
|
||||
if computed := sh.Digest(); computed != b.Digest {
|
||||
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
|
||||
}
|
||||
|
||||
// explicitly close the file so we can rename it
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
@@ -326,38 +517,69 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
|
||||
// downloadChunkToDisk streams a part directly to disk at its offset.
|
||||
// Memory: ~32KB (read buffer only).
|
||||
// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
|
||||
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
startTime := time.Now()
|
||||
var bytesAtLastCheck atomic.Int64
|
||||
|
||||
g.Go(func() error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
w := io.NewOffsetWriter(file, part.Offset)
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
var written int64
|
||||
for written < part.Size {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
written += int64(n)
|
||||
b.Completed.Add(int64(n))
|
||||
bytesAtLastCheck.Store(written)
|
||||
|
||||
part.lastUpdatedMu.Lock()
|
||||
part.lastUpdated = time.Now()
|
||||
part.lastUpdatedMu.Unlock()
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
b.Completed.Add(-written)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
part.Completed.Add(n)
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
// Record speed for this part
|
||||
elapsed := time.Since(startTime).Seconds()
|
||||
if elapsed > 0 {
|
||||
tracker.Record(float64(part.Size) / elapsed)
|
||||
}
|
||||
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
part.Completed.Store(part.Size)
|
||||
return b.writePart(part.Name(), part)
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
var lastBytes int64
|
||||
checksWithoutProgress := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
@@ -365,19 +587,47 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
|
||||
return nil
|
||||
}
|
||||
|
||||
currentBytes := bytesAtLastCheck.Load()
|
||||
|
||||
// Check for complete stall (30 seconds no progress)
|
||||
part.lastUpdatedMu.Lock()
|
||||
lastUpdated := part.lastUpdated
|
||||
part.lastUpdatedMu.Unlock()
|
||||
|
||||
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
|
||||
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
||||
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
|
||||
part.lastUpdatedMu.Lock()
|
||||
part.lastUpdated = time.Time{}
|
||||
part.lastUpdatedMu.Unlock()
|
||||
return errPartStalled
|
||||
}
|
||||
|
||||
// Check for slow speed after 5+ seconds (only for multi-part downloads)
|
||||
// Skip if we've already retried for slowness too many times
|
||||
elapsed := time.Since(startTime).Seconds()
|
||||
if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
|
||||
currentSpeed := float64(currentBytes) / elapsed
|
||||
median := tracker.Median()
|
||||
|
||||
// If we're below 10% of median speed, flag as slow
|
||||
if median > 0 && currentSpeed < median*0.1 {
|
||||
slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
|
||||
b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
|
||||
return errPartSlow
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if speed dropped significantly mid-download
|
||||
if currentBytes == lastBytes {
|
||||
checksWithoutProgress++
|
||||
if checksWithoutProgress >= 10 {
|
||||
slog.Info(fmt.Sprintf("%s part %d no progress for 10s; retrying", b.Digest[7:19], part.N))
|
||||
return errPartStalled
|
||||
}
|
||||
} else {
|
||||
checksWithoutProgress = 0
|
||||
}
|
||||
lastBytes = currentBytes
|
||||
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
319
server/download_test.go
Normal file
319
server/download_test.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSpeedTracker_Median(t *testing.T) {
|
||||
s := &speedTracker{}
|
||||
|
||||
// Less than 3 samples returns 0
|
||||
s.Record(100)
|
||||
s.Record(200)
|
||||
if got := s.Median(); got != 0 {
|
||||
t.Errorf("expected 0 with < 3 samples, got %f", got)
|
||||
}
|
||||
|
||||
// With 3+ samples, returns median
|
||||
s.Record(300)
|
||||
// Samples: [100, 200, 300] -> median = 200
|
||||
if got := s.Median(); got != 200 {
|
||||
t.Errorf("expected median 200, got %f", got)
|
||||
}
|
||||
|
||||
// Add more samples
|
||||
s.Record(50)
|
||||
s.Record(250)
|
||||
// Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200
|
||||
if got := s.Median(); got != 200 {
|
||||
t.Errorf("expected median 200, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpeedTracker_RollingWindow(t *testing.T) {
|
||||
s := &speedTracker{}
|
||||
|
||||
// Add 105 samples (should keep only last 100)
|
||||
for i := 0; i < 105; i++ {
|
||||
s.Record(float64(i))
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if len(s.speeds) != 100 {
|
||||
t.Errorf("expected 100 samples, got %d", len(s.speeds))
|
||||
}
|
||||
// First sample should be 5 (0-4 were dropped)
|
||||
if s.speeds[0] != 5 {
|
||||
t.Errorf("expected first sample to be 5, got %f", s.speeds[0])
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestSpeedTracker_Concurrent(t *testing.T) {
|
||||
s := &speedTracker{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(v int) {
|
||||
defer wg.Done()
|
||||
s.Record(float64(v))
|
||||
s.Median() // concurrent read
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic, and should have reasonable state
|
||||
s.mu.Lock()
|
||||
if len(s.speeds) == 0 || len(s.speeds) > 100 {
|
||||
t.Errorf("unexpected speeds length: %d", len(s.speeds))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStreamHasher_Sequential(t *testing.T) {
|
||||
// Create temp file
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
// Write test data
|
||||
data := []byte("hello world, this is a test of the stream hasher")
|
||||
if _, err := f.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create parts
|
||||
parts := []*blobDownloadPart{
|
||||
{Offset: 0, Size: int64(len(data))},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, int64(len(data)))
|
||||
|
||||
// Mark complete and run
|
||||
sh.MarkComplete(0)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
<-done
|
||||
|
||||
// Verify digest
|
||||
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
|
||||
if got := sh.Digest(); got != expected {
|
||||
t.Errorf("digest mismatch: got %s, want %s", got, expected)
|
||||
}
|
||||
|
||||
if err := sh.Err(); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHasher_OutOfOrderCompletion(t *testing.T) {
|
||||
// Create temp file
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
// Write test data (3 parts of 10 bytes each)
|
||||
data := []byte("0123456789ABCDEFGHIJabcdefghij")
|
||||
if _, err := f.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create 3 parts
|
||||
parts := []*blobDownloadPart{
|
||||
{N: 0, Offset: 0, Size: 10},
|
||||
{N: 1, Offset: 10, Size: 10},
|
||||
{N: 2, Offset: 20, Size: 10},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, int64(len(data)))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Mark parts complete out of order: 2, 0, 1
|
||||
sh.MarkComplete(2)
|
||||
sh.MarkComplete(0) // This should trigger hashing of part 0
|
||||
sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2
|
||||
|
||||
<-done
|
||||
|
||||
// Verify digest
|
||||
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
|
||||
if got := sh.Digest(); got != expected {
|
||||
t.Errorf("digest mismatch: got %s, want %s", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHasher_Stop(t *testing.T) {
|
||||
// Create temp file
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
parts := []*blobDownloadPart{
|
||||
{Offset: 0, Size: 100},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, 100)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Stop without completing any parts
|
||||
sh.Stop()
|
||||
<-done
|
||||
|
||||
// Should exit cleanly without error
|
||||
if err := sh.Err(); err != nil {
|
||||
t.Errorf("unexpected error after Stop: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHasher_HashedProgress(t *testing.T) {
|
||||
// Create temp file with known data
|
||||
f, err := os.CreateTemp("", "streamhasher_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
data := make([]byte, 1000)
|
||||
rand.Read(data)
|
||||
if _, err := f.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
parts := []*blobDownloadPart{
|
||||
{N: 0, Offset: 0, Size: 500},
|
||||
{N: 1, Offset: 500, Size: 500},
|
||||
}
|
||||
|
||||
sh := newStreamHasher(f, parts, 1000)
|
||||
|
||||
// Initially no progress
|
||||
if got := sh.Hashed(); got != 0 {
|
||||
t.Errorf("expected 0 hashed initially, got %d", got)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Complete part 0
|
||||
sh.MarkComplete(0)
|
||||
|
||||
// Give hasher time to process
|
||||
for i := 0; i < 100; i++ {
|
||||
if sh.Hashed() >= 500 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Complete part 1
|
||||
sh.MarkComplete(1)
|
||||
<-done
|
||||
|
||||
if got := sh.Hashed(); got != 1000 {
|
||||
t.Errorf("expected 1000 hashed, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSpeedTracker_Record(b *testing.B) {
|
||||
s := &speedTracker{}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Record(float64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSpeedTracker_Median(b *testing.B) {
|
||||
s := &speedTracker{}
|
||||
// Pre-populate with 100 samples
|
||||
for i := 0; i < 100; i++ {
|
||||
s.Record(float64(i))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Median()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStreamHasher(b *testing.B) {
|
||||
// Create temp file with test data
|
||||
f, err := os.CreateTemp("", "streamhasher_bench")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
size := 64 * 1024 * 1024 // 64MB
|
||||
data := make([]byte, size)
|
||||
rand.Read(data)
|
||||
if _, err := f.Write(data); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
parts := []*blobDownloadPart{
|
||||
{Offset: 0, Size: int64(size)},
|
||||
}
|
||||
|
||||
b.SetBytes(int64(size))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
sh := newStreamHasher(f, parts, int64(size))
|
||||
sh.MarkComplete(0)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sh.Run()
|
||||
close(done)
|
||||
}()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHashThroughput(b *testing.B) {
|
||||
// Baseline: raw SHA256 throughput on this machine
|
||||
size := 256 * 1024 * 1024 // 256MB
|
||||
data := make([]byte, size)
|
||||
rand.Read(data)
|
||||
|
||||
b.SetBytes(int64(size))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := sha256.New()
|
||||
h.Write(data)
|
||||
h.Sum(nil)
|
||||
}
|
||||
}
|
||||
@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
_, err := downloadBlob(ctx, downloadOpts{
|
||||
mp: mp,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
skipVerify[layer.Digest] = cacheHit
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
|
||||
for _, layer := range layers {
|
||||
if skipVerify[layer.Digest] {
|
||||
continue
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
// something went wrong, delete the blob
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(fp); err != nil {
|
||||
// log this, but return the original error
|
||||
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Note: Digest verification now happens inline during download in blobDownload.run()
|
||||
// via the orderedWriter, so no separate verification pass is needed.
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
|
||||
52
server/internal/cache/blob/cache.go
vendored
52
server/internal/cache/blob/cache.go
vendored
@@ -10,7 +10,6 @@ import (
|
||||
"hash"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -327,21 +326,19 @@ func (c *DiskCache) GetFile(d Digest) string {
|
||||
return absJoin(c.dir, "blobs", filename)
|
||||
}
|
||||
|
||||
// Links returns a sequence of link names. The sequence is in lexical order.
|
||||
// Links returns a slice of link names in lexical order.
|
||||
// Names are converted from their relative path form to their name form but are
|
||||
// not guaranteed to be valid. Callers should validate the names before using.
|
||||
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
for path, err := range c.links() {
|
||||
if err != nil {
|
||||
yield("", err)
|
||||
return
|
||||
}
|
||||
if !yield(pathToName(path), nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
func (c *DiskCache) Links() ([]string, error) {
|
||||
paths, err := c.links()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
names := make([]string, len(paths))
|
||||
for i, path := range paths {
|
||||
names[i] = pathToName(path)
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// pathToName converts a path to a name. It is the inverse of nameToPath. The
|
||||
@@ -372,10 +369,11 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
|
||||
}
|
||||
|
||||
maybe := filepath.Join("manifests", np)
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
paths, err := c.links()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, l := range paths {
|
||||
if strings.EqualFold(maybe, l) {
|
||||
return filepath.Join(c.dir, l), nil
|
||||
}
|
||||
@@ -383,22 +381,10 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
|
||||
return filepath.Join(c.dir, maybe), nil
|
||||
}
|
||||
|
||||
// links returns a sequence of links in the cache in lexical order.
|
||||
func (c *DiskCache) links() iter.Seq2[string, error] {
|
||||
// TODO(bmizerany): reuse empty dirnames if exist
|
||||
return func(yield func(string, error) bool) {
|
||||
fsys := os.DirFS(c.dir)
|
||||
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
|
||||
if err != nil {
|
||||
yield("", err)
|
||||
return
|
||||
}
|
||||
for _, manifest := range manifests {
|
||||
if !yield(manifest, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// links returns a slice of link paths in the cache in lexical order.
|
||||
func (c *DiskCache) links() ([]string, error) {
|
||||
fsys := os.DirFS(c.dir)
|
||||
return fs.Glob(fsys, "manifests/*/*/*/*")
|
||||
}
|
||||
|
||||
type checkWriter struct {
|
||||
|
||||
27
server/internal/cache/blob/cache_test.go
vendored
27
server/internal/cache/blob/cache_test.go
vendored
@@ -466,12 +466,9 @@ func testManifestNameReuse(t *testing.T) {
|
||||
t.Fatalf("g = %v, want %v", g, w)
|
||||
}
|
||||
|
||||
var got []string
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
got, err := c.links()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []string{"manifests/h/n/m/t"}
|
||||
if !slices.Equal(got, want) {
|
||||
@@ -487,12 +484,9 @@ func testManifestNameReuse(t *testing.T) {
|
||||
err = c.Link("h/n/m:T", d1)
|
||||
check(err)
|
||||
|
||||
got = got[:0]
|
||||
for l, err := range c.links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
got, err = c.links()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// we should have only one link that is same case as the last link
|
||||
@@ -554,12 +548,9 @@ func TestNames(t *testing.T) {
|
||||
check(c.Link("h/n/m:t", mkdigest("1")))
|
||||
check(c.Link("h/n/m:u", mkdigest("2")))
|
||||
|
||||
var got []string
|
||||
for l, err := range c.Links() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = append(got, l)
|
||||
got, err := c.Links()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []string{"h/n/m:t", "h/n/m:u"}
|
||||
if !slices.Equal(got, want) {
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -546,18 +545,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
})
|
||||
}()
|
||||
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
// Note the chunksum stream
|
||||
// interruption, but do not cancel
|
||||
// in-flight downloads. We can still
|
||||
// make progress on them. Once they are
|
||||
// done, ErrIncomplete will be returned
|
||||
// below.
|
||||
update(0, err)
|
||||
break
|
||||
}
|
||||
|
||||
err = r.chunksums(ctx, name, l, func(cs chunksum) bool {
|
||||
cacheKey := fmt.Sprintf(
|
||||
"v1 pull chunksum %s %s %d-%d",
|
||||
l.Digest,
|
||||
@@ -569,7 +557,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
update(cs.Chunk.Size(), ErrCached)
|
||||
continue
|
||||
return true // continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
@@ -620,6 +608,13 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
// Record the downloading of this chunk.
|
||||
return blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
})
|
||||
return true // continue processing chunks
|
||||
})
|
||||
if err != nil {
|
||||
// Note the chunksum stream interruption, but do not cancel
|
||||
// in-flight downloads. We can still make progress on them.
|
||||
// Once they are done, ErrIncomplete will be returned below.
|
||||
update(0, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -674,19 +669,6 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manifest) All() iter.Seq[*Layer] {
|
||||
return func(yield func(*Layer) bool) {
|
||||
if !yield(m.Config) {
|
||||
return
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if !yield(l) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manifest) Size() int64 {
|
||||
var size int64
|
||||
if m.Config != nil {
|
||||
@@ -811,125 +793,114 @@ type chunksum struct {
|
||||
Digest blob.Digest
|
||||
}
|
||||
|
||||
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
|
||||
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
|
||||
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
|
||||
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
|
||||
return func(yield func(chunksum, error) bool) {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
// chunksums calls fn for each chunksum in the layer. If the layer is under the
|
||||
// chunking threshold, a single chunksum covering the entire layer is passed to fn.
|
||||
// If the layer is over the chunking threshold, chunksums are read from the chunksums endpoint.
|
||||
// Returns an error if the chunksum stream fails, or nil if all chunksums were processed.
|
||||
// If fn returns false, iteration stops early and chunksums returns nil.
|
||||
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer, fn func(chunksum) bool) error {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if l.Size < r.maxChunkingThreshold() {
|
||||
// any layer under the threshold should be downloaded
|
||||
// in one go.
|
||||
cs := chunksum{
|
||||
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
),
|
||||
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
||||
Digest: l.Digest,
|
||||
}
|
||||
fn(cs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// The response is a sequence of chunksums.
|
||||
//
|
||||
// Chunksums are chunks of a larger blob that can be
|
||||
// downloaded and verified independently.
|
||||
//
|
||||
// The chunksums endpoint is a GET request that returns a
|
||||
// sequence of chunksums in the following format:
|
||||
//
|
||||
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// < HTTP/1.1 200 OK
|
||||
// < Content-Location: <blobURL>
|
||||
// <
|
||||
// < <digest> <start>-<end>
|
||||
// < ...
|
||||
//
|
||||
// The <blobURL> is the URL to download the chunks from and
|
||||
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||
// is the range the chunk in the blob.
|
||||
//
|
||||
// Ranges may be used directly in Range headers like
|
||||
// "bytes=<start>-<end>".
|
||||
//
|
||||
// The chunksums returned are guaranteed to be contiguous and
|
||||
// include all bytes of the layer. If the stream is cut short,
|
||||
// clients should retry.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
)
|
||||
|
||||
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
return fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
||||
}
|
||||
blobURL := res.Header.Get("Content-Location")
|
||||
|
||||
s := bufio.NewScanner(res.Body)
|
||||
s.Split(bufio.ScanWords)
|
||||
for {
|
||||
if !s.Scan() {
|
||||
return s.Err()
|
||||
}
|
||||
d, err := blob.ParseDigest(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
return fmt.Errorf("invalid digest: %q", s.Bytes())
|
||||
}
|
||||
|
||||
if l.Size < r.maxChunkingThreshold() {
|
||||
// any layer under the threshold should be downloaded
|
||||
// in one go.
|
||||
cs := chunksum{
|
||||
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
),
|
||||
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
||||
Digest: l.Digest,
|
||||
if !s.Scan() {
|
||||
err := s.Err()
|
||||
if err == nil {
|
||||
err = fmt.Errorf("missing chunk range for digest %s", d)
|
||||
}
|
||||
yield(cs, nil)
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
// The response is a sequence of chunksums.
|
||||
//
|
||||
// Chunksums are chunks of a larger blob that can be
|
||||
// downloaded and verified independently.
|
||||
//
|
||||
// The chunksums endpoint is a GET request that returns a
|
||||
// sequence of chunksums in the following format:
|
||||
//
|
||||
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// < HTTP/1.1 200 OK
|
||||
// < Content-Location: <blobURL>
|
||||
// <
|
||||
// < <digest> <start>-<end>
|
||||
// < ...
|
||||
//
|
||||
// The <blobURL> is the URL to download the chunks from and
|
||||
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||
// is the range the chunk in the blob.
|
||||
//
|
||||
// Ranges may be used directly in Range headers like
|
||||
// "bytes=<start>-<end>".
|
||||
//
|
||||
// The chunksums returned are guaranteed to be contiguous and
|
||||
// include all bytes of the layer. If the stream is cut short,
|
||||
// clients should retry.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
)
|
||||
|
||||
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
||||
chunk, err := parseChunk(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
return fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes())
|
||||
}
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
|
||||
cs := chunksum{
|
||||
URL: blobURL,
|
||||
Chunk: chunk,
|
||||
Digest: d,
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
blobURL := res.Header.Get("Content-Location")
|
||||
|
||||
s := bufio.NewScanner(res.Body)
|
||||
s.Split(bufio.ScanWords)
|
||||
for {
|
||||
if !s.Scan() {
|
||||
if s.Err() != nil {
|
||||
yield(chunksum{}, s.Err())
|
||||
}
|
||||
return
|
||||
}
|
||||
d, err := blob.ParseDigest(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
if !s.Scan() {
|
||||
err := s.Err()
|
||||
if err == nil {
|
||||
err = fmt.Errorf("missing chunk range for digest %s", d)
|
||||
}
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
chunk, err := parseChunk(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
cs := chunksum{
|
||||
URL: blobURL,
|
||||
Chunk: chunk,
|
||||
Digest: d,
|
||||
}
|
||||
if !yield(cs, nil) {
|
||||
return
|
||||
}
|
||||
if !fn(cs) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1176,8 +1147,8 @@ func splitExtended(s string) (scheme, name, digest string) {
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
// parseChunk parses a string in the form "start-end" and returns the Chunk.
|
||||
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
|
||||
// parseChunk parses a byte slice in the form "start-end" and returns the Chunk.
|
||||
func parseChunk(s []byte) (blob.Chunk, error) {
|
||||
startPart, endPart, found := strings.Cut(string(s), "-")
|
||||
if !found {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
|
||||
|
||||
@@ -27,46 +27,20 @@ type Trace struct {
|
||||
}
|
||||
|
||||
func (t *Trace) update(l *Layer, n int64, err error) {
|
||||
if t.Update != nil {
|
||||
if t != nil && t.Update != nil {
|
||||
t.Update(l, n, err)
|
||||
}
|
||||
}
|
||||
|
||||
type traceKey struct{}
|
||||
|
||||
// WithTrace adds a trace to the context for transfer progress reporting.
|
||||
// WithTrace attaches a Trace to the context for transfer progress reporting.
|
||||
func WithTrace(ctx context.Context, t *Trace) context.Context {
|
||||
old := traceFromContext(ctx)
|
||||
if old == t {
|
||||
// No change, return the original context. This also prevents
|
||||
// infinite recursion below, if the caller passes the same
|
||||
// Trace.
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Create a new Trace that wraps the old one, if any. If we used the
|
||||
// same pointer t, we end up with a recursive structure.
|
||||
composed := &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if old != nil {
|
||||
old.update(l, n, err)
|
||||
}
|
||||
t.update(l, n, err)
|
||||
},
|
||||
}
|
||||
return context.WithValue(ctx, traceKey{}, composed)
|
||||
return context.WithValue(ctx, traceKey{}, t)
|
||||
}
|
||||
|
||||
var emptyTrace = &Trace{}
|
||||
|
||||
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
|
||||
// none is found.
|
||||
//
|
||||
// It never returns nil.
|
||||
// traceFromContext returns the Trace associated with ctx, or nil if none.
|
||||
func traceFromContext(ctx context.Context) *Trace {
|
||||
t, _ := ctx.Value(traceKey{}).(*Trace)
|
||||
if t == nil {
|
||||
return emptyTrace
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -2,44 +2,46 @@ package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
|
||||
var n int
|
||||
return func(yield func(int, error) bool) {
|
||||
var t *time.Timer
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
yield(n, ctx.Err())
|
||||
return
|
||||
}
|
||||
// Retry calls fn repeatedly with exponential backoff until it returns nil,
|
||||
// a non-retryable error (shouldRetry returns false), or the context is cancelled.
|
||||
// The shouldRetry function determines if an error is retryable.
|
||||
// Returns the last error encountered, or nil if fn succeeded.
|
||||
func Retry(ctx context.Context, maxBackoff time.Duration, shouldRetry func(error) bool, fn func() error) error {
|
||||
var t *time.Timer
|
||||
for n := 0; ; n++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !yield(n, nil) {
|
||||
return
|
||||
}
|
||||
err := fn()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !shouldRetry(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
n++
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
if t == nil {
|
||||
t = time.NewTimer(d)
|
||||
} else {
|
||||
t.Reset(d)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
case <-t.C:
|
||||
}
|
||||
if t == nil {
|
||||
t = time.NewTimer(d)
|
||||
} else {
|
||||
t.Reset(d)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return ctx.Err()
|
||||
case <-t.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,31 +10,70 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoop(t *testing.T) {
|
||||
func TestRetry(t *testing.T) {
|
||||
synctest.Run(func() {
|
||||
last := -1
|
||||
n := 0
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
for n, err := range Loop(ctx, 100*time.Millisecond) {
|
||||
if !errors.Is(err, ctx.Err()) {
|
||||
t.Errorf("err = %v, want nil", err)
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if n != last+1 {
|
||||
t.Errorf("n = %d, want %d", n, last+1)
|
||||
}
|
||||
last = n
|
||||
err := Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
|
||||
n++
|
||||
if n > 5 {
|
||||
cancel()
|
||||
}
|
||||
return errors.New("keep going")
|
||||
})
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("err = %v, want context.Canceled", err)
|
||||
}
|
||||
|
||||
if last != 6 {
|
||||
t.Errorf("last = %d, want 6", last)
|
||||
if n != 6 {
|
||||
t.Errorf("n = %d, want 6", n)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetrySuccess(t *testing.T) {
|
||||
synctest.Run(func() {
|
||||
n := 0
|
||||
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool { return true }, func() error {
|
||||
n++
|
||||
if n >= 3 {
|
||||
return nil // success
|
||||
}
|
||||
return errors.New("retry")
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("err = %v, want nil", err)
|
||||
}
|
||||
if n != 3 {
|
||||
t.Errorf("n = %d, want 3", n)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryNonRetryable(t *testing.T) {
|
||||
synctest.Run(func() {
|
||||
permanent := errors.New("permanent error")
|
||||
n := 0
|
||||
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool {
|
||||
return !errors.Is(err, permanent)
|
||||
}, func() error {
|
||||
n++
|
||||
if n >= 2 {
|
||||
return permanent
|
||||
}
|
||||
return errors.New("retry")
|
||||
})
|
||||
|
||||
if !errors.Is(err, permanent) {
|
||||
t.Errorf("err = %v, want permanent", err)
|
||||
}
|
||||
if n != 2 {
|
||||
t.Errorf("n = %d, want 2", n)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,37 +3,46 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoopAllocs(t *testing.T) {
|
||||
var errRetry = errors.New("retry")
|
||||
|
||||
func TestRetryAllocs(t *testing.T) {
|
||||
for i := range 3 {
|
||||
got := testing.AllocsPerRun(1000, func() {
|
||||
for tick := range Loop(t.Context(), 1) {
|
||||
tick := 0
|
||||
Retry(t.Context(), 1, func(err error) bool { return true }, func() error {
|
||||
tick++
|
||||
if tick >= i {
|
||||
break
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errRetry
|
||||
})
|
||||
})
|
||||
want := float64(0)
|
||||
if i > 0 {
|
||||
want = 3 // due to time.NewTimer
|
||||
}
|
||||
if got > want {
|
||||
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
|
||||
t.Errorf("[%d ticks]: allocs = %v, want <= %v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoop(b *testing.B) {
|
||||
func BenchmarkRetry(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
synctest.Run(func() {
|
||||
for n := range Loop(ctx, 100*time.Millisecond) {
|
||||
n := 0
|
||||
Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
|
||||
n++
|
||||
if n == b.N {
|
||||
break
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errRetry
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "DELETE" {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
p, err := decodeParams(r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
p, err := decodeParams(r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -293,10 +293,14 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
t := time.NewTicker(1<<63 - 1) // "unstarted" timer
|
||||
// ticker controls periodic progress flushing. It starts paused (very long
|
||||
// interval) and is activated by start() once all layers are registered,
|
||||
// so clients see a complete total before progress begins.
|
||||
ticker := time.NewTicker(1 << 62) // effectively paused until started
|
||||
defer ticker.Stop()
|
||||
start := sync.OnceFunc(func() {
|
||||
flushProgress() // flush initial state
|
||||
t.Reset(100 * time.Millisecond)
|
||||
flushProgress()
|
||||
ticker.Reset(100 * time.Millisecond)
|
||||
})
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
@@ -320,36 +324,21 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
})
|
||||
}()
|
||||
|
||||
// Block flushing progress updates until every
|
||||
// layer is accounted for. Clients depend on a
|
||||
// complete model size to calculate progress
|
||||
// correctly; if they use an incomplete total,
|
||||
// progress indicators would erratically jump
|
||||
// as new layers are registered.
|
||||
start()
|
||||
},
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() (err error) {
|
||||
defer func() { done <- err }()
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := s.Client.Pull(ctx, p.model())
|
||||
if canRetry(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
go func() {
|
||||
done <- backoff.Retry(ctx, 3*time.Second, canRetry, func() error {
|
||||
return s.Client.Pull(ctx, p.model())
|
||||
})
|
||||
}()
|
||||
|
||||
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
case <-ticker.C:
|
||||
flushProgress()
|
||||
case err := <-done:
|
||||
flushProgress()
|
||||
@@ -374,20 +363,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
var v T
|
||||
err := json.NewDecoder(r).Decode(&v)
|
||||
func decodeParams(r io.Reader) (*params, error) {
|
||||
var p params
|
||||
err := json.NewDecoder(r).Decode(&p)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
return &p, nil
|
||||
}
|
||||
var zero T
|
||||
|
||||
// Not sure why, but I can't seem to be able to use:
|
||||
//
|
||||
// errors.As(err, &json.UnmarshalTypeError{})
|
||||
//
|
||||
// This is working fine in stdlib, so I'm not sure what rules changed
|
||||
// and why this no longer works here. So, we do it the verbose way.
|
||||
var a *json.UnmarshalTypeError
|
||||
var b *json.SyntaxError
|
||||
if errors.As(err, &a) || errors.As(err, &b) {
|
||||
@@ -396,7 +378,7 @@ func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
|
||||
}
|
||||
return zero, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func canRetry(err error) bool {
|
||||
@@ -408,10 +390,8 @@ func canRetry(err error) bool {
|
||||
return oe.Temporary()
|
||||
}
|
||||
s := err.Error()
|
||||
return cmp.Or(
|
||||
errors.Is(err, context.DeadlineExceeded),
|
||||
strings.Contains(s, "unreachable"),
|
||||
strings.Contains(s, "no route to host"),
|
||||
strings.Contains(s, "connection reset by peer"),
|
||||
)
|
||||
return errors.Is(err, context.DeadlineExceeded) ||
|
||||
strings.Contains(s, "unreachable") ||
|
||||
strings.Contains(s, "no route to host") ||
|
||||
strings.Contains(s, "connection reset by peer")
|
||||
}
|
||||
|
||||
@@ -1106,6 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
Messages: msgs,
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
Requires: m.Config.Requires,
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" {
|
||||
|
||||
@@ -363,7 +363,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather information\",\"parameters\":{\"type\":\"\"}}}]user: Get the weather\n",
|
||||
expectTemplate: "[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather information\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]user: Get the weather\n",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import "os"
|
||||
|
||||
func setSparse(*os.File) {
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func setSparse(file *os.File) {
|
||||
// exFat (and other FS types) don't support sparse files, so ignore errors
|
||||
windows.DeviceIoControl( //nolint:errcheck
|
||||
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
|
||||
nil, 0,
|
||||
nil, 0,
|
||||
nil, nil,
|
||||
)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ type ConfigV2 struct {
|
||||
FileType string `json:"file_type"` // shown as Quantization Level
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
Requires string `json:"requires,omitempty"`
|
||||
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user