mirror of
https://github.com/ollama/ollama.git
synced 2026-01-17 20:11:14 -05:00
Compare commits
1 Commits
fix-nemotr
...
imagegen-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8e56dab90b |
33
api/types.go
33
api/types.go
@@ -127,6 +127,25 @@ type GenerateRequest struct {
|
||||
// each with an associated log probability. Only applies when Logprobs is true.
|
||||
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// Experimental: Image generation fields (may change or be removed)
|
||||
|
||||
// Width is the width of the generated image in pixels.
|
||||
// Only used for image generation models.
|
||||
Width int32 `json:"width,omitempty"`
|
||||
|
||||
// Height is the height of the generated image in pixels.
|
||||
// Only used for image generation models.
|
||||
Height int32 `json:"height,omitempty"`
|
||||
|
||||
// Steps is the number of diffusion steps for image generation.
|
||||
// Only used for image generation models.
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
|
||||
// Seed is the random seed for reproducible image generation.
|
||||
// If 0 or not specified, a random seed will be used.
|
||||
// Only used for image generation models.
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -860,6 +879,20 @@ type GenerateResponse struct {
|
||||
// Logprobs contains log probability information for the generated tokens,
|
||||
// if requested via the Logprobs parameter.
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Experimental: Image generation fields (may change or be removed)
|
||||
|
||||
// Images contains base64-encoded generated images.
|
||||
// Only present for image generation models.
|
||||
Images []string `json:"images,omitempty"`
|
||||
|
||||
// Completed is the number of completed steps in image generation.
|
||||
// Only present for image generation models during streaming.
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
|
||||
// Total is the total number of steps for image generation.
|
||||
// Only present for image generation models during streaming.
|
||||
Total int64 `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
|
||||
@@ -600,7 +600,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
// Check if this is an image generation model
|
||||
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
|
||||
if slices.Contains(info.Capabilities, model.CapabilityImage) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
@@ -1985,6 +1985,7 @@ func NewCLI() *cobra.Command {
|
||||
} {
|
||||
switch cmd {
|
||||
case runCmd:
|
||||
imagegen.AppendFlagsDocs(cmd)
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
|
||||
case serveCmd:
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{
|
||||
|
||||
@@ -1555,7 +1555,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImageGeneration},
|
||||
Capabilities: []model.Capability{model.CapabilityImage},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b)
|
||||
if err != nil {
|
||||
|
||||
63
docs/api.md
63
docs/api.md
@@ -16,6 +16,7 @@
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
- [Version](#version)
|
||||
- [Experimental: Image Generation](#image-generation-experimental)
|
||||
|
||||
## Conventions
|
||||
|
||||
@@ -58,6 +59,16 @@ Advanced parameters (optional):
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
||||
|
||||
Experimental image generation parameters (for image generation models only):
|
||||
|
||||
> [!WARNING]
|
||||
> These parameters are experimental and may change in future versions.
|
||||
|
||||
- `width`: width of the generated image in pixels (default: model-specific, typically 1024)
|
||||
- `height`: height of the generated image in pixels (default: model-specific, typically 1024)
|
||||
- `steps`: number of diffusion steps (default: model-specific)
|
||||
- `seed`: random seed for reproducible image generation (default: random)
|
||||
|
||||
#### Structured outputs
|
||||
|
||||
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
|
||||
@@ -1867,3 +1878,55 @@ curl http://localhost:11434/api/version
|
||||
"version": "0.5.1"
|
||||
}
|
||||
```
|
||||
|
||||
## Experimental Features
|
||||
|
||||
### Image Generation (Experimental)
|
||||
|
||||
> [!WARNING]
|
||||
> Image generation is experimental and may change in future versions.
|
||||
|
||||
Image generation is now supported through the standard `/api/generate` endpoint when using image generation models (such as Flux). The API automatically detects when an image generation model is being used.
|
||||
|
||||
See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`, `seed`) are documented there.
|
||||
|
||||
#### Example
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "flux",
|
||||
"prompt": "a sunset over mountains",
|
||||
"width": 1024,
|
||||
"height": 768
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response (streaming)
|
||||
|
||||
Progress updates during generation:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "flux",
|
||||
"created_at": "2024-01-15T10:30:00.000000Z",
|
||||
"completed": 5,
|
||||
"total": 20,
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
##### Final Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "flux",
|
||||
"created_at": "2024-01-15T10:30:15.000000Z",
|
||||
"images": ["iVBORw0KGgoAAAANSUhEUg..."],
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 15000000000,
|
||||
"load_duration": 2000000000
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1468,6 +1468,7 @@ type CompletionRequest struct {
|
||||
// Image generation fields
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1518,10 +1519,14 @@ type CompletionResponse struct {
|
||||
// Logprobs contains log probability information if requested
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Image generation fields
|
||||
Image []byte `json:"image,omitempty"` // Generated image
|
||||
Step int `json:"step,omitempty"` // Current generation step
|
||||
Total int `json:"total,omitempty"` // Total generation steps
|
||||
// Image contains base64-encoded image data for image generation
|
||||
Image string `json:"image,omitempty"`
|
||||
|
||||
// Step is the current step in image generation
|
||||
Step int `json:"step,omitempty"`
|
||||
|
||||
// TotalSteps is the total number of steps for image generation
|
||||
TotalSteps int `json:"total_steps,omitempty"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
|
||||
@@ -546,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type ImageWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
if err := json.Unmarshal(data, &generateResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Only write response when done with images
|
||||
if generateResponse.Done && len(generateResponse.Images) > 0 {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ImageWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageGenerationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerationsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image generation basic",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "a beautiful sunset"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "a beautiful sunset",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation with size",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "a beautiful sunset",
|
||||
"size": "512x768"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "a beautiful sunset",
|
||||
Width: 512,
|
||||
Height: 768,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation missing prompt",
|
||||
body: `{
|
||||
"model": "test-model"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation missing model",
|
||||
body: `{
|
||||
"prompt": "a beautiful sunset"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Error.Message != "" {
|
||||
var errResp openai.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match:\n%s", diff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageWriterResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Test that ImageWriter transforms GenerateResponse to OpenAI format
|
||||
endpoint := func(c *gin.Context) {
|
||||
resp := api.GenerateResponse{
|
||||
Model: "test-model",
|
||||
CreatedAt: time.Unix(1234567890, 0).UTC(),
|
||||
Done: true,
|
||||
Images: []string{"dGVzdC1pbWFnZS1kYXRh"}, // base64 of "test-image-data"
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware())
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
body := `{"model": "test-model", "prompt": "test"}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var imageResp openai.ImageGenerationResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if imageResp.Created != 1234567890 {
|
||||
t.Errorf("expected created 1234567890, got %d", imageResp.Created)
|
||||
}
|
||||
|
||||
if len(imageResp.Data) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
|
||||
}
|
||||
|
||||
if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
|
||||
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
@@ -13,114 +14,243 @@ const (
|
||||
Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota
|
||||
Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
Nemotron3NanoCollectingContent
|
||||
Nemotron3NanoCollectingToolCalls
|
||||
)
|
||||
|
||||
const (
|
||||
nemotronThinkClose = "</think>"
|
||||
nemotronToolCallOpen = "<tool_call>"
|
||||
nemotronThinkClose = "</think>"
|
||||
nemotronToolCallOpen = "<tool_call>"
|
||||
nemotronToolCallClose = "</tool_call>"
|
||||
)
|
||||
|
||||
type Nemotron3NanoParser struct {
|
||||
state Nemotron3NanoParserState
|
||||
buffer strings.Builder
|
||||
toolParser *Qwen3CoderParser
|
||||
state Nemotron3NanoParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
|
||||
func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }
|
||||
|
||||
func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.toolParser = &Qwen3CoderParser{}
|
||||
p.toolParser.Init(tools, nil, nil)
|
||||
p.tools = tools
|
||||
|
||||
// thinking is enabled if user requests it
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
|
||||
if !thinkingEnabled || (prefill && lastMessage.Content != "") {
|
||||
if !thinkingEnabled {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingThinking
|
||||
return tools
|
||||
}
|
||||
|
||||
if prefill && lastMessage.Content != "" {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
p.state = Nemotron3NanoCollectingThinking
|
||||
return tools
|
||||
}
|
||||
|
||||
type nemotronEvent interface {
|
||||
isNemotronEvent()
|
||||
}
|
||||
|
||||
type nemotronEventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type nemotronEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type nemotronEventToolCall struct {
|
||||
toolCall api.ToolCall
|
||||
}
|
||||
|
||||
func (nemotronEventThinkingContent) isNemotronEvent() {}
|
||||
func (nemotronEventContent) isNemotronEvent() {}
|
||||
func (nemotronEventToolCall) isNemotronEvent() {}
|
||||
|
||||
func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
if p.state == Nemotron3NanoCollectingContent {
|
||||
return p.toolParser.Add(s, done)
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case nemotronEventToolCall:
|
||||
toolCalls = append(toolCalls, event.toolCall)
|
||||
case nemotronEventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case nemotronEventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
if p.state == Nemotron3NanoSkipWhitespaceAfterThinking {
|
||||
s = strings.TrimLeftFunc(s, unicode.IsSpace)
|
||||
if s == "" {
|
||||
return "", "", nil, nil
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent {
|
||||
var all []nemotronEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []nemotronEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// emitWithPartialCheck extracts unambiguous content before a potential partial tag
|
||||
func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
|
||||
if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
|
||||
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||
return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:]
|
||||
}
|
||||
wsLen := trailingWhitespaceLen(bufStr)
|
||||
return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:]
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) {
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case Nemotron3NanoCollectingThinking:
|
||||
if strings.Contains(bufStr, nemotronThinkClose) {
|
||||
split := strings.SplitN(bufStr, nemotronThinkClose, 2)
|
||||
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.WriteString(remainder)
|
||||
// Transition to whitespace-skipping state if buffer is empty,
|
||||
// otherwise go directly to content collection
|
||||
if remainder == "" {
|
||||
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
}
|
||||
if thinking != "" {
|
||||
return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambig)
|
||||
if unambig != "" {
|
||||
return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false
|
||||
}
|
||||
return nil, false
|
||||
|
||||
// We only want to skip whitespace between thinking and content
|
||||
case Nemotron3NanoSkipWhitespaceAfterThinking:
|
||||
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr)
|
||||
if bufStr == "" {
|
||||
return nil, false
|
||||
}
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
return p.toolParser.Add(s, done)
|
||||
}
|
||||
return nil, true
|
||||
|
||||
// Nemotron3NanoCollectingThinking - buffer and look for end markers
|
||||
p.buffer.WriteString(s)
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Look for end of thinking: </think> or <tool_call> (model may skip </think>)
|
||||
thinkIdx := strings.Index(bufStr, nemotronThinkClose)
|
||||
toolIdx := strings.Index(bufStr, nemotronToolCallOpen)
|
||||
|
||||
var endIdx int = -1
|
||||
var remainder string
|
||||
|
||||
if thinkIdx != -1 && (toolIdx == -1 || thinkIdx < toolIdx) {
|
||||
endIdx = thinkIdx
|
||||
remainder = strings.TrimLeftFunc(bufStr[thinkIdx+len(nemotronThinkClose):], unicode.IsSpace)
|
||||
} else if toolIdx != -1 {
|
||||
endIdx = toolIdx
|
||||
remainder = bufStr[toolIdx:] // Include <tool_call> tag
|
||||
}
|
||||
|
||||
if endIdx != -1 {
|
||||
thinking = strings.TrimRightFunc(bufStr[:endIdx], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
|
||||
if remainder == "" {
|
||||
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
content, _, calls, err = p.toolParser.Add(remainder, done)
|
||||
case Nemotron3NanoCollectingContent:
|
||||
if strings.Contains(bufStr, nemotronToolCallOpen) {
|
||||
split := strings.SplitN(bufStr, nemotronToolCallOpen, 2)
|
||||
content := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(split[1])
|
||||
p.state = Nemotron3NanoCollectingToolCalls
|
||||
if content != "" {
|
||||
return []nemotronEvent{nemotronEventContent{content: content}}, true
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
return content, thinking, calls, err
|
||||
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambig)
|
||||
if unambig != "" {
|
||||
return []nemotronEvent{nemotronEventContent{content: unambig}}, false
|
||||
}
|
||||
return nil, false
|
||||
|
||||
case Nemotron3NanoCollectingToolCalls:
|
||||
if strings.Contains(bufStr, nemotronToolCallClose) {
|
||||
split := strings.SplitN(bufStr, nemotronToolCallClose, 2)
|
||||
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
|
||||
var events []nemotronEvent
|
||||
if tc, err := p.parseToolCall(split[0]); err == nil {
|
||||
events = append(events, nemotronEventToolCall{toolCall: tc})
|
||||
}
|
||||
|
||||
if !strings.Contains(remaining, nemotronToolCallOpen) {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// No end marker - emit unambiguous thinking
|
||||
thinking = p.emitThinking(bufStr)
|
||||
return "", thinking, nil, nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// emitThinking returns unambiguous thinking content, keeping potential partial tags in buffer
|
||||
func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
|
||||
// Check for partial </think> or <tool_call> at end
|
||||
thinkOverlap := overlap(bufStr, nemotronThinkClose)
|
||||
toolOverlap := overlap(bufStr, nemotronToolCallOpen)
|
||||
maxOverlap := max(thinkOverlap, toolOverlap)
|
||||
var (
|
||||
nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`)
|
||||
nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`)
|
||||
)
|
||||
|
||||
if maxOverlap > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-maxOverlap]
|
||||
unambiguous = strings.TrimRightFunc(unambiguous, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr[len(bufStr)-maxOverlap:])
|
||||
return unambiguous
|
||||
func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) {
|
||||
toolCall := api.ToolCall{}
|
||||
|
||||
// Extract function name
|
||||
fnMatch := nemotronFunctionRegex.FindStringSubmatch(content)
|
||||
if len(fnMatch) < 2 {
|
||||
return toolCall, nil
|
||||
}
|
||||
toolCall.Function.Name = fnMatch[1]
|
||||
|
||||
// Extract parameters
|
||||
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range paramMatches {
|
||||
if len(match) >= 3 {
|
||||
paramName := match[1]
|
||||
paramValue := strings.TrimSpace(match[2])
|
||||
|
||||
// Try to parse as typed value based on tool definition
|
||||
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
|
||||
}
|
||||
}
|
||||
|
||||
// No partial tags - emit all but trailing whitespace
|
||||
wsLen := trailingWhitespaceLen(bufStr)
|
||||
if wsLen > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-wsLen]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr[len(bufStr)-wsLen:])
|
||||
return unambiguous
|
||||
}
|
||||
|
||||
// Nothing to hold back
|
||||
p.buffer.Reset()
|
||||
return bufStr
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any {
|
||||
// Find the matching tool to get parameter type
|
||||
var paramType api.PropertyType
|
||||
for _, tool := range p.tools {
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
|
||||
paramType = prop.Type
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parseValue(raw, paramType)
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestNemotron3NanoParser tests Nemotron-specific behavior (thinking support).
|
||||
// Tool call parsing is tested in qwen3coder_test.go since Nemotron delegates to Qwen3CoderParser.
|
||||
func TestNemotron3NanoParser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -19,6 +17,18 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "simple content - no thinking",
|
||||
input: "Hello, how can I help you?",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello, how can I help you?",
|
||||
},
|
||||
{
|
||||
name: "simple content - thinking disabled",
|
||||
input: "Hello, how can I help you?",
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expectedContent: "Hello, how can I help you?",
|
||||
},
|
||||
{
|
||||
name: "thinking then content",
|
||||
input: "Let me think about this...</think>\nHere is my answer.",
|
||||
@@ -33,6 +43,69 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude",
|
||||
expectedContent: "The answer is 42.",
|
||||
},
|
||||
{
|
||||
name: "simple tool call",
|
||||
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "content then tool call",
|
||||
input: "Let me check the weather.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nNYC\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Let me check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple parameters",
|
||||
input: "<tool_call>\n<function=book_flight>\n<parameter=from>\nSFO\n</parameter>\n<parameter=to>\nNYC\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls",
|
||||
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nSan Francisco\n</parameter>\n</function>\n</tool_call>\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nNew York\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking then tool call",
|
||||
input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
|
||||
@@ -62,6 +135,19 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiline parameter value",
|
||||
input: "<tool_call>\n<function=create_note>\n<parameter=content>\nLine 1\nLine 2\nLine 3\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty thinking block - immediate close",
|
||||
input: "</think>\nHere is my answer.",
|
||||
@@ -75,6 +161,18 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expectedContent: "</think>\nSome content after spurious tag.",
|
||||
},
|
||||
{
|
||||
name: "tool call with no function name - returns empty tool call",
|
||||
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
|
||||
},
|
||||
{
|
||||
name: "content with newlines preserved",
|
||||
input: "Line 1\n\nLine 2\n\n\nLine 3",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Line 1\n\nLine 2\n\n\nLine 3",
|
||||
},
|
||||
{
|
||||
name: "thinking with only whitespace after close tag",
|
||||
input: "My thoughts...</think> \n\t\n Content here.",
|
||||
@@ -82,6 +180,25 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
expectedThinking: "My thoughts...",
|
||||
expectedContent: "Content here.",
|
||||
},
|
||||
{
|
||||
name: "unicode content",
|
||||
input: "Hello 世界! 🌍 Ñoño",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello 世界! 🌍 Ñoño",
|
||||
},
|
||||
{
|
||||
name: "tool call with numeric parameter",
|
||||
input: "<tool_call>\n<function=set_temp>\n<parameter=value>\n42\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: testArgs(map[string]any{"value": "42"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -116,8 +233,6 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNemotron3NanoParser_Streaming tests streaming behavior for thinking support.
|
||||
// Tool call streaming is tested in qwen3coder_test.go.
|
||||
func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -127,6 +242,18 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "streaming content character by character",
|
||||
chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello, world!",
|
||||
},
|
||||
{
|
||||
name: "streaming content small tokens",
|
||||
chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello, how can I help you today?",
|
||||
},
|
||||
{
|
||||
name: "streaming thinking then content - granular",
|
||||
chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."},
|
||||
@@ -141,6 +268,45 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking: "Step 1: Analyze\nStep 2: Process",
|
||||
expectedContent: "The answer.",
|
||||
},
|
||||
{
|
||||
name: "streaming tool call - highly granular",
|
||||
chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "</", "param", "eter", ">", "\n", "</", "func", "tion", ">", "\n", "</", "tool", "_", "call", ">"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming content then tool call - granular",
|
||||
chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "<function=", "get_weather", ">", "\n", "<parameter=", "city", ">", "\n", "NYC", "\n", "</parameter>", "\n", "</function>", "\n", "</tool_call>"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Let me check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call tag split character by character",
|
||||
chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking close tag split character by character",
|
||||
chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"},
|
||||
@@ -155,6 +321,22 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking: "Thinking...",
|
||||
expectedContent: "Content here.",
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple parameters - streaming",
|
||||
chunks: []string{"<tool_", "call>\n", "<function", "=book_", "flight>", "\n<para", "meter=", "from>\n", "SFO\n", "</param", "eter>", "\n<param", "eter=to", ">\nNYC", "\n</para", "meter>", "\n</func", "tion>\n", "</tool_", "call>"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking then content then tool call - streaming",
|
||||
chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"},
|
||||
@@ -170,6 +352,45 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls - streaming",
|
||||
chunks: []string{
|
||||
"<tool_call>", "\n", "<function=", "get_weather>", "\n",
|
||||
"<parameter=", "city>\n", "San Fran", "cisco\n", "</parameter>", "\n",
|
||||
"</function>", "\n", "</tool_call>", "\n",
|
||||
"<tool_", "call>\n", "<function", "=get_weather", ">\n",
|
||||
"<param", "eter=city", ">\nNew", " York\n", "</parameter>\n",
|
||||
"</function>\n", "</tool_call>",
|
||||
},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiline parameter - streaming",
|
||||
chunks: []string{"<tool_call>\n", "<function=", "create_note>\n", "<parameter=", "content>\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n</parameter>\n", "</function>\n", "</tool_call>"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty thinking block",
|
||||
chunks: []string{"</think>", "\n", "Just content."},
|
||||
@@ -177,6 +398,12 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking: "",
|
||||
expectedContent: "Just content.",
|
||||
},
|
||||
{
|
||||
name: "empty input chunks interspersed",
|
||||
chunks: []string{"Hello", "", " ", "", "world", "", "!"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello world!",
|
||||
},
|
||||
{
|
||||
name: "tool call immediately after think close - no content",
|
||||
chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"},
|
||||
@@ -191,6 +418,25 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with empty parameter value",
|
||||
chunks: []string{"<tool_call>\n<function=test>\n<parameter=name>\n", "\n</parameter>\n</function>\n</tool_call>"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: testArgs(map[string]any{"name": ""}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "partial tool call tag at end - buffered",
|
||||
chunks: []string{"Here's some content", "<tool"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Here's some content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -326,65 +572,3 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNemotron3NanoParser_ToolCallWithoutThinkClose tests the case where thinking is enabled
|
||||
// but the model outputs content + tool call WITHOUT the </think> tag.
|
||||
// The parser should still parse the tool call (content before is treated as thinking).
|
||||
func TestNemotron3NanoParser_ToolCallWithoutThinkClose(t *testing.T) {
|
||||
chunks := []string{
|
||||
"Let", " me", " analyze", " this", ".", "\n",
|
||||
"<tool_call>", "\n",
|
||||
"<function=get_weather>", "\n",
|
||||
"<parameter=city>", "Paris", "</parameter>", "\n",
|
||||
"</function>", "\n",
|
||||
"</tool_call>",
|
||||
}
|
||||
|
||||
p := &Nemotron3NanoParser{}
|
||||
p.Init(nil, nil, &api.ThinkValue{Value: true}) // thinking ENABLED but model doesn't output </think>
|
||||
|
||||
var allContent string
|
||||
var allThinking string
|
||||
var allCalls []api.ToolCall
|
||||
|
||||
for _, chunk := range chunks {
|
||||
content, thinking, calls, err := p.Add(chunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allThinking += thinking
|
||||
allCalls = append(allCalls, calls...)
|
||||
}
|
||||
|
||||
// Drain
|
||||
content, thinking, calls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allThinking += thinking
|
||||
allCalls = append(allCalls, calls...)
|
||||
|
||||
// The parser was in thinking mode, so text before <tool_call> is emitted as thinking.
|
||||
expectedThinking := "Let me analyze this."
|
||||
|
||||
expectedCalls := []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if allContent != "" {
|
||||
t.Errorf("expected no content (text was streamed as thinking), got: %q", allContent)
|
||||
}
|
||||
if diff := cmp.Diff(allThinking, expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, expectedCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,37 +91,6 @@ func TestQwenParserStreaming(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call tags split character by character",
|
||||
steps: []step{
|
||||
{input: "<", wantEvents: []qwenEvent{}},
|
||||
{input: "t", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "_", wantEvents: []qwenEvent{}},
|
||||
{input: "c", wantEvents: []qwenEvent{}},
|
||||
{input: "a", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: ">", wantEvents: []qwenEvent{}},
|
||||
{input: "a", wantEvents: []qwenEvent{}},
|
||||
{input: "b", wantEvents: []qwenEvent{}},
|
||||
{input: "c", wantEvents: []qwenEvent{}},
|
||||
{input: "<", wantEvents: []qwenEvent{}},
|
||||
{input: "/", wantEvents: []qwenEvent{}},
|
||||
{input: "t", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "_", wantEvents: []qwenEvent{}},
|
||||
{input: "c", wantEvents: []qwenEvent{}},
|
||||
{input: "a", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: ">", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "abc"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between content and tool call",
|
||||
steps: []step{
|
||||
|
||||
@@ -737,3 +737,57 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageURLOrData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageURLOrData contains either a URL or base64-encoded image data.
|
||||
type ImageURLOrData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
|
||||
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
if r.Seed != nil {
|
||||
req.Seed = *r.Seed
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
|
||||
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
|
||||
data := make([]ImageURLOrData, 0)
|
||||
for _, img := range resp.Images {
|
||||
data = append(data, ImageURLOrData{B64JSON: img})
|
||||
}
|
||||
return ImageGenerationResponse{
|
||||
Created: resp.CreatedAt.Unix(),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ var (
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errCapabilityThinking = errors.New("thinking")
|
||||
errCapabilityImage = errors.New("image generation")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
@@ -76,7 +77,7 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
|
||||
// Check for image generation model via config capabilities
|
||||
if slices.Contains(m.Config.Capabilities, "image") {
|
||||
return []model.Capability{model.CapabilityImageGeneration}
|
||||
return []model.Capability{model.CapabilityImage}
|
||||
}
|
||||
|
||||
// Check for completion capability
|
||||
@@ -159,6 +160,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
model.CapabilityThinking: errCapabilityThinking,
|
||||
model.CapabilityImage: errCapabilityImage,
|
||||
}
|
||||
|
||||
for _, cap := range want {
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
@@ -242,6 +242,24 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
checkCaps: []model.Capability{"unknown"},
|
||||
expectedErrMsg: "unknown capability",
|
||||
},
|
||||
{
|
||||
name: "model missing image generation capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityImage},
|
||||
expectedErrMsg: "does not support image generation",
|
||||
},
|
||||
{
|
||||
name: "model with image generation capability",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
145
server/routes.go
145
server/routes.go
@@ -220,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle image generation models
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
s.handleImageGenerate(c, req, name.String(), checkpointStart)
|
||||
return
|
||||
}
|
||||
|
||||
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
|
||||
return
|
||||
@@ -1096,7 +1102,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
|
||||
// For image generation models, populate details from imagegen package
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||
modelDetails.Family = info.Architecture
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||
@@ -1202,7 +1208,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
@@ -1594,8 +1600,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// Experimental OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", s.handleImageGeneration)
|
||||
// OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
@@ -1917,62 +1923,6 @@ func toolCallId() string {
|
||||
return "call_" + strings.ToLower(string(b))
|
||||
}
|
||||
|
||||
func (s *Server) handleImageGeneration(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Size string `json:"size"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size (e.g., "1024x768") into width and height
|
||||
width, height := int32(1024), int32(1024)
|
||||
if req.Size != "" {
|
||||
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var image []byte
|
||||
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
}, func(resp llm.CompletionResponse) {
|
||||
if len(resp.Image) > 0 {
|
||||
image = resp.Image
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"created": time.Now().Unix(),
|
||||
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) ChatHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
@@ -2522,3 +2472,78 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
// handleImageGenerate handles image generation requests within GenerateHandler.
|
||||
// This is called when the model has the ImageGeneration capability.
|
||||
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
|
||||
// Validate image dimensions
|
||||
const maxDimension int32 = 4096
|
||||
if req.Width > maxDimension || req.Height > maxDimension {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule the runner for image generation
|
||||
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
// Handle load-only request (empty prompt)
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers for streaming response
|
||||
c.Header("Content-Type", "application/x-ndjson")
|
||||
|
||||
var streamStarted bool
|
||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: req.Seed,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
streamStarted = true
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: cr.Done,
|
||||
}
|
||||
|
||||
if cr.TotalSteps > 0 {
|
||||
res.Completed = int64(cr.Step)
|
||||
res.Total = int64(cr.TotalSteps)
|
||||
}
|
||||
|
||||
if cr.Image != "" {
|
||||
res.Images = []string{cr.Image}
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.Metrics.TotalDuration = time.Since(checkpointStart)
|
||||
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(res)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
c.Writer.Flush()
|
||||
}); err != nil {
|
||||
// Only send JSON error if streaming hasn't started yet
|
||||
// (once streaming starts, headers are committed and we can't change status code)
|
||||
if !streamStarted {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -571,10 +571,10 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
llama: server,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
sessionDuration: sessionDuration,
|
||||
refCount: 1,
|
||||
totalSize: server.TotalSize(),
|
||||
vramSize: server.VRAMSize(),
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
|
||||
@@ -9,7 +9,7 @@ const (
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImageGeneration = Capability("image")
|
||||
CapabilityImage = Capability("image")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
@@ -51,6 +51,7 @@ func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
// Hide from main flags section - shown in separate section via AppendFlagsDocs
|
||||
cmd.Flags().MarkHidden("width")
|
||||
cmd.Flags().MarkHidden("height")
|
||||
cmd.Flags().MarkHidden("steps")
|
||||
@@ -58,6 +59,19 @@ func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().MarkHidden("negative")
|
||||
}
|
||||
|
||||
// AppendFlagsDocs appends image generation flags documentation to the command's usage template.
|
||||
func AppendFlagsDocs(cmd *cobra.Command) {
|
||||
usage := `
|
||||
Image Generation Flags (experimental):
|
||||
--width int Image width
|
||||
--height int Image height
|
||||
--steps int Denoising steps
|
||||
--seed int Random seed
|
||||
--negative str Negative prompt
|
||||
`
|
||||
cmd.SetUsageTemplate(cmd.UsageTemplate() + usage)
|
||||
}
|
||||
|
||||
// RunCLI handles the CLI for image generation models.
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
@@ -91,9 +105,7 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
|
||||
}
|
||||
|
||||
// generateImageWithOptions generates an image with the given options.
|
||||
// Note: opts are currently unused as the native API doesn't support size parameters.
|
||||
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -102,7 +114,10 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
Seed: int64(opts.Seed),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -116,32 +131,25 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
var stepBar *progress.StepBar
|
||||
var imageBase64 string
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.Stop()
|
||||
p.StopAndClear()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -179,6 +187,23 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
return err
|
||||
}
|
||||
|
||||
// Preload the model with the specified keepalive
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
preloadReq := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
KeepAlive: keepAlive,
|
||||
}
|
||||
if err := client.Generate(cmd.Context(), preloadReq, func(resp api.GenerateResponse) error {
|
||||
return nil
|
||||
}); err != nil {
|
||||
p.StopAndClear()
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
p.StopAndClear()
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
Placeholder: "Describe an image to generate (/help for commands)",
|
||||
@@ -235,12 +260,10 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
Seed: int64(opts.Seed),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -255,32 +278,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
var imageBase64 string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.Stop()
|
||||
p.StopAndClear()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
continue
|
||||
|
||||
@@ -36,6 +36,8 @@ type Response struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Image string `json:"image,omitempty"` // Base64-encoded PNG
|
||||
Done bool `json:"done"`
|
||||
Step int `json:"step,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
@@ -167,8 +169,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Seed: req.Seed,
|
||||
Progress: func(step, total int) {
|
||||
resp := Response{
|
||||
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
|
||||
Done: false,
|
||||
Step: step,
|
||||
Total: total,
|
||||
Done: false,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -232,11 +231,13 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
}
|
||||
|
||||
@@ -279,15 +280,11 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
// Convert to llm.CompletionResponse
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
Total: raw.Total,
|
||||
}
|
||||
if raw.Image != "" {
|
||||
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
|
||||
cresp.Image = data
|
||||
}
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
TotalSteps: raw.Total,
|
||||
Image: raw.Image,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
|
||||
Reference in New Issue
Block a user