Compare commits

...

3 Commits

Author SHA1 Message Date
Michael Yang
d05fc26570 null truncate 2025-08-25 10:00:16 -07:00
Michael Yang
c457628090 null stream 2025-08-25 10:00:15 -07:00
Michael Yang
e914477bb6 types: add types.Null[T]
there's a common pattern where request fields may need to differentiate
between an unset value and a value set to the type's zero value. this is
commonly used to apply a different default value, e.g. stream, or to
omit a field entirely, e.g. think.

similar to sql.Null[T], types.Null[T] simplifies this by providing
utilities to quickly and easily apply this pattern to any type using
generics.
2025-08-25 09:49:02 -07:00
11 changed files with 190 additions and 94 deletions

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types"
"github.com/ollama/ollama/types/model"
)
@@ -64,7 +65,7 @@ type GenerateRequest struct {
Context []int `json:"context,omitempty"`
// Stream specifies whether the response is streaming; it is true by default.
Stream *bool `json:"stream,omitempty"`
Stream types.Null[bool] `json:"stream,omitempty"`
// Raw set to true means that no formatting will be applied to the prompt.
Raw bool `json:"raw,omitempty"`
@@ -105,7 +106,7 @@ type ChatRequest struct {
Messages []Message `json:"messages"`
// Stream enables streaming of returned responses; true by default.
Stream *bool `json:"stream,omitempty"`
Stream types.Null[bool] `json:"stream,omitempty"`
// Format is the format to return the response in (e.g. "json").
Format json.RawMessage `json:"format,omitempty"`
@@ -381,7 +382,7 @@ type EmbedRequest struct {
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
Truncate types.Null[bool] `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
@@ -420,9 +421,9 @@ type EmbeddingResponse struct {
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
Model string `json:"model"`
Stream *bool `json:"stream,omitempty"`
Quantize string `json:"quantize,omitempty"`
Model string `json:"model"`
Stream types.Null[bool] `json:"stream,omitempty"`
Quantize string `json:"quantize,omitempty"`
From string `json:"from,omitempty"`
Files map[string]string `json:"files,omitempty"`
@@ -486,11 +487,11 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Stream *bool `json:"stream,omitempty"`
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Stream types.Null[bool] `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -507,11 +508,11 @@ type ProgressResponse struct {
// PushRequest is the request passed to [Client.Push].
type PushRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Stream types.Null[bool] `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`

View File

@@ -17,6 +17,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types"
"github.com/ollama/ollama/types/model"
)
@@ -571,7 +572,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Messages: messages,
Format: format,
Options: options,
Stream: &r.Stream,
Stream: types.NullWithValue(r.Stream),
Tools: r.Tools,
Think: think,
}, nil
@@ -650,7 +651,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
Model: r.Model,
Prompt: r.Prompt,
Options: options,
Stream: &r.Stream,
Stream: types.NullWithValue(r.Stream),
Suffix: r.Suffix,
}, nil
}

View File

@@ -146,7 +146,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- api.ProgressResponse{Status: "success"}
}()
if r.Stream != nil && !*r.Stream {
if !r.Stream.Value(true) {
waitForStream(c, ch)
return
}

View File

@@ -440,7 +440,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}()
if req.Stream != nil && !*req.Stream {
if !req.Stream.Value(true) {
var r api.GenerateResponse
var sbThinking strings.Builder
var sbContent strings.Builder
@@ -487,12 +487,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string
switch i := req.Input.(type) {
@@ -541,6 +535,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
var count int
truncate := req.Truncate.Value(true)
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
@@ -701,7 +696,7 @@ func (s *Server) PullHandler(c *gin.Context) {
}
}()
if req.Stream != nil && !*req.Stream {
if !req.Stream.Value(true) {
waitForStream(c, ch)
return
}
@@ -756,7 +751,7 @@ func (s *Server) PushHandler(c *gin.Context) {
}
}()
if req.Stream != nil && !*req.Stream {
if !req.Stream.Value(true) {
waitForStream(c, ch)
return
}
@@ -1775,7 +1770,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}()
if req.Stream != nil && !*req.Stream {
if !req.Stream.Value(true) {
var resp api.ChatResponse
var toolCalls []api.ToolCall
var sbThinking strings.Builder

View File

@@ -22,8 +22,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
)
var stream bool = false
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
t.Helper()
t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))
@@ -118,7 +116,7 @@ func TestCreateFromBin(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -148,7 +146,7 @@ func TestCreateFromModel(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -162,7 +160,7 @@ func TestCreateFromModel(t *testing.T) {
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2",
From: "test",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -192,7 +190,7 @@ func TestCreateRemovesLayers(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ .Prompt }}",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -213,7 +211,7 @@ func TestCreateRemovesLayers(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ .System }} {{ .Prompt }}",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -243,7 +241,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
System: "Say hi!",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -264,7 +262,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
System: "",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -297,7 +295,7 @@ func TestCreateMergeParameters(t *testing.T) {
"top_k": 10,
"stop": []string{"USER:", "ASSISTANT:"},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -322,7 +320,7 @@ func TestCreateMergeParameters(t *testing.T) {
"temperature": 0.6,
"top_p": 0.7,
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -381,7 +379,7 @@ func TestCreateMergeParameters(t *testing.T) {
"top_p": 0.7,
"stop": []string{"<|endoftext|>"},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -441,7 +439,7 @@ func TestCreateReplacesMessages(t *testing.T) {
Content: "Oh, my god.",
},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -475,7 +473,7 @@ func TestCreateReplacesMessages(t *testing.T) {
Content: "A test. And a thumping good one at that, I'd wager.",
},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -536,7 +534,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Files: map[string]string{"test.gguf": digest},
Template: "{{ .System }} {{ .Prompt }}",
System: "Say bye!",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -578,7 +576,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ .Prompt",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusBadRequest {
@@ -592,7 +590,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ if .Prompt }}",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusBadRequest {
@@ -606,7 +604,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ Prompt }}",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusBadRequest {
@@ -627,7 +625,7 @@ func TestCreateLicenses(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
License: []string{"MIT", "Apache-2.0"},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -678,7 +676,7 @@ func TestCreateDetectTemplate(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -698,7 +696,7 @@ func TestCreateDetectTemplate(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/types"
)
func TestGenerateDebugRenderOnly(t *testing.T) {
@@ -53,7 +54,6 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
@@ -82,7 +82,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ .Prompt }}",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -172,7 +172,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
req := tt.request
req.Stream = &stream
req.Stream = types.NullWithValue(stream)
w := createRequest(t, s.GenerateHandler, req)
if tt.expectDebug {
@@ -246,7 +246,6 @@ func TestChatDebugRenderOnly(t *testing.T) {
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
@@ -275,7 +274,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -377,7 +376,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
req := tt.request
req.Stream = &stream
req.Stream = types.NullWithValue(stream)
w := createRequest(t, s.ChatHandler, req)
if tt.expectDebug {

View File

@@ -126,7 +126,7 @@ func TestGenerateChat(t *testing.T) {
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}`,
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -182,7 +182,7 @@ func TestGenerateChat(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",
Files: map[string]string{"bert.gguf": digest},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -288,7 +288,7 @@ func TestGenerateChat(t *testing.T) {
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -318,7 +318,7 @@ func TestGenerateChat(t *testing.T) {
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -340,7 +340,7 @@ func TestGenerateChat(t *testing.T) {
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -363,7 +363,7 @@ func TestGenerateChat(t *testing.T) {
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -422,15 +422,13 @@ func TestGenerateChat(t *testing.T) {
EvalDuration: 1,
}
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &streamRequest,
Stream: streamTrue,
})
if w.Code != http.StatusOK {
@@ -551,7 +549,7 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &stream,
Stream: streamFalse,
})
wg.Wait()
@@ -666,7 +664,7 @@ func TestGenerate(t *testing.T) {
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}
`,
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -704,7 +702,7 @@ func TestGenerate(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -825,7 +823,7 @@ func TestGenerate(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -853,7 +851,7 @@ func TestGenerate(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -873,7 +871,7 @@ func TestGenerate(t *testing.T) {
Model: "test-system",
Prompt: "Hello!",
System: "You can perform magic tricks.",
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -895,7 +893,7 @@ func TestGenerate(t *testing.T) {
Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -957,7 +955,7 @@ func TestGenerate(t *testing.T) {
Model: "test-system",
Prompt: "Help me write tests.",
Raw: true,
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -1040,7 +1038,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
{{- if eq .Role "user" }}user: {{ .Content }}
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
{{ end }}{{ end }}<think>`,
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -1066,13 +1064,12 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
mock.CompletionFn = nil
streamRequest := false
req := api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{
{Role: "user", Content: userContent},
},
Stream: &streamRequest,
Stream: streamFalse,
}
if think {
req.Think = &api.ThinkValue{Value: think}
@@ -1165,7 +1162,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
Think: &api.ThinkValue{Value: think},
Stream: &stream,
Stream: streamFalse,
})
wg.Wait()

View File

@@ -291,12 +291,11 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
// Create a simple test model
_, digest := createHarmonyTestModel(t)
streamFalse := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test-streaming",
Files: map[string]string{"test.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: &streamFalse,
Stream: streamFalse,
})
if w.Code != 200 {
@@ -304,11 +303,10 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test-streaming",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Stream: streamTrue,
Tools: getTestTools(),
})
@@ -441,12 +439,11 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
// Create model
_, digest := createHarmonyTestModel(t)
streamFalse := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "gpt-oss",
Files: map[string]string{"test.gguf": digest},
Template: `<|start|><|end|>{{ .Tools }}{{ .Prompt }}`,
Stream: &streamFalse,
Stream: streamFalse,
})
if w.Code != 200 {
@@ -454,11 +451,10 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
}
// Test streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "gpt-oss",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Stream: streamTrue,
Tools: getTestTools(),
})
@@ -625,12 +621,11 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
_, digest := createHarmonyTestModel(t)
// Create model with passthrough template
stream := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test",
Files: map[string]string{"file.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: &stream,
Stream: streamFalse,
})
if w.Code != http.StatusOK {
@@ -638,11 +633,10 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Stream: streamTrue,
Tools: getTestTools(),
})

View File

@@ -28,10 +28,16 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
var (
streamFalse = types.NullWithValue(false)
streamTrue = types.NullWithValue(true)
)
func createTestFile(t *testing.T, name string) (string, string) {
t.Helper()
@@ -332,11 +338,10 @@ func TestRoutes(t *testing.T) {
Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) {
_, digest := createTestFile(t, "ollama-model")
stream := false
createReq := api.CreateRequest{
Name: "t-bone",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
Stream: streamFalse,
}
jsonData, err := json.Marshal(createReq)
if err != nil {
@@ -638,7 +643,7 @@ func TestManifestCaseSensitivity(t *testing.T) {
// version.
Name: wantStableName,
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
Stream: streamFalse,
}))
checkManifestList()
@@ -646,14 +651,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
Name: name(),
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
Stream: streamFalse,
}))
checkManifestList()
t.Logf("pulling")
checkOK(createRequest(t, s.PullHandler, api.PullRequest{
Name: name(),
Stream: &stream,
Stream: streamFalse,
Insecure: true,
}))
checkManifestList()

53
types/null.go Normal file
View File

@@ -0,0 +1,53 @@
package types
import (
"encoding/json"
)
// Null represents a value of any type T that may be null.
type Null[T any] struct {
value T
valid bool
}
// NullWithValue creates a new, valid Null[T].
func NullWithValue[T any](value T) Null[T] {
return Null[T]{value: value, valid: true}
}
// Value returns the value of the Type[T] if set, otherwise it returns the provided default value or the zero value of T.
func (n Null[T]) Value(defaultValue ...T) T {
if n.valid {
return n.value
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
var zero T
return zero
}
// SetValue sets the value of the Type[T].
func (n *Null[T]) SetValue(t T) {
n.value = t
n.valid = true
}
// MarshalJSON implements [json.Marshaler].
func (n Null[T]) MarshalJSON() ([]byte, error) {
if n.valid {
return json.Marshal(n.value)
}
return []byte("null"), nil
}
// UnmarshalJSON implements [json.Unmarshaler].
func (n *Null[T]) UnmarshalJSON(data []byte) error {
if string(data) != "null" {
if err := json.Unmarshal(data, &n.value); err != nil {
return err
}
n.valid = true
}
return nil
}

53
types/null_test.go Normal file
View File

@@ -0,0 +1,53 @@
package types_test
import (
"encoding/json"
"testing"
"github.com/ollama/ollama/types"
)
func TestNull(t *testing.T) {
var s types.Null[string]
if val := s.Value(); val != "" {
t.Errorf("expected Value to return zero value '', got '%s'", val)
}
if val := s.Value("default"); val != "default" {
t.Errorf("expected Value to return default value 'default', got '%s'", val)
}
if bts, err := json.Marshal(s); err != nil {
t.Errorf("unexpected error during MarshalJSON: %v", err)
} else if want := "null"; string(bts) != want {
t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts))
}
s.SetValue("foo")
if val := s.Value(); val != "foo" {
t.Errorf("expected Value to return 'foo', got '%s'", val)
}
s = types.NullValue("bar")
if val := s.Value(); val != "bar" {
t.Errorf("expected Value to return 'bar', got '%s'", val)
}
if bts, err := json.Marshal(s); err != nil {
t.Errorf("unexpected error during MarshalJSON: %v", err)
} else if want := `"bar"`; string(bts) != want {
t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts))
}
if err := json.Unmarshal([]byte(`null`), &s); err != nil {
t.Errorf("unexpected error during UnmarshalJSON: %v", err)
}
if err := json.Unmarshal([]byte(`"baz"`), &s); err != nil {
t.Errorf("unexpected error during UnmarshalJSON: %v", err)
}
if err := json.Unmarshal([]byte(`1.2345`), &s); err == nil {
t.Error("expected error during UnmarshalJSON with invalid JSON, got nil")
}
}