Compare commits

...

1 Commit

Author SHA1 Message Date
jmorganca
29837c1b98 Add experimental /x/generate endpoint for image generation
This adds a new experimental endpoint /x/generate specifically for image
generation models, keeping the main /api/generate endpoint unchanged.

New endpoint:
- POST /x/generate - experimental image generation endpoint
- Supports width, height, steps parameters
- Returns progress updates and base64-encoded images
- Validates that the model supports image generation

API changes:
- Add width, height, steps parameters to GenerateRequest
- Add status, total, completed, images fields to GenerateResponse
- Add XGenerate method to api.Client for calling /x/generate

OpenAI compatibility:
- /v1/images/generations now routes through /x/generate
- Uses middleware pattern like other OpenAI endpoints
- Returns OpenAI-compatible response format with b64_json data

CLI:
- imagegen CLI now uses /x/generate via client.XGenerate()
- Supports --width, --height, --steps flags

Internal changes:
- Add XGenerateHandler to server/routes.go
- Update llm.CompletionRequest/Response with image generation fields
- Change Image field from []byte to string (base64-encoded)
- Add Steps field to CompletionRequest
- Rename Total to TotalSteps for clarity
2026-01-16 21:50:19 -08:00
10 changed files with 518 additions and 114 deletions

View File

@@ -281,6 +281,20 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
})
}
// XGenerate generates images using the experimental /x/generate endpoint.
// Unlike [Client.Generate], this endpoint targets image generation models
// and honors the width, height, and steps fields of the request. fn is
// invoked once per streamed response; returning an error from fn stops
// the stream and propagates that error to the caller.
func (c *Client) XGenerate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
	handle := func(raw []byte) error {
		var r GenerateResponse
		if err := json.Unmarshal(raw, &r); err != nil {
			return err
		}
		return fn(r)
	}
	return c.stream(ctx, http.MethodPost, "/x/generate", req, handle)
}
// ChatResponseFunc is a function that [Client.Chat] invokes every time
// a response is received from the service. If this function returns an error,
// [Client.Chat] will stop generating and return this error.

View File

@@ -97,6 +97,15 @@ type GenerateRequest struct {
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// Width is the width of the generated image (for image generation models).
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image (for image generation models).
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps (for image generation models).
Steps int32 `json:"steps,omitempty"`
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
@@ -860,6 +869,18 @@ type GenerateResponse struct {
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
// Status describes the current phase of generation (e.g., "generating image").
Status string `json:"status,omitempty"`
// Total is the total count for the current phase (e.g., total steps).
Total int64 `json:"total,omitempty"`
// Completed is the completed count for the current phase.
Completed int64 `json:"completed,omitempty"`
// Images contains base64-encoded generated images for image generation models.
Images []string `json:"images,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -16,6 +16,7 @@
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Version](#version)
- [Experimental: Generate an image](#generate-an-image-experimental)
## Conventions
@@ -1867,3 +1868,85 @@ curl http://localhost:11434/api/version
"version": "0.5.1"
}
```
## Experimental Endpoints
### Generate an image (Experimental)
```
POST /x/generate
```
> [!WARNING]
> This endpoint is experimental and may change in future versions.
Generate an image using an image generation model. This endpoint is specifically designed for diffusion-based image generation models.
#### Parameters
- `model`: (required) the [model name](#model-names) of an image generation model
- `prompt`: the text prompt describing the image to generate (if omitted, the model is loaded into memory and an empty done response is returned)
Image generation parameters (optional):
- `width`: width of the generated image in pixels (default: model-specific, typically 1024)
- `height`: height of the generated image in pixels (default: model-specific, typically 1024)
- `steps`: number of diffusion steps (default: model-specific)
Other parameters:
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
#### Response
The response is streamed as JSON objects showing generation progress:
- `status`: describes the current phase (e.g., "generating image")
- `total`: total number of steps
- `completed`: number of completed steps
- `done`: whether generation is complete
The final response includes:
- `images`: array of base64-encoded generated images
- `total_duration`: time spent generating the image
- `load_duration`: time spent loading the model
#### Examples
##### Request
```shell
curl http://localhost:11434/x/generate -d '{
"model": "flux",
"prompt": "a sunset over mountains",
"width": 1024,
"height": 768
}'
```
##### Response (streaming)
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:00.000000Z",
"status": "generating image",
"completed": 5,
"total": 20,
"done": false
}
```
##### Final Response
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:15.000000Z",
"images": ["iVBORw0KGgoAAAANSUhEUg..."],
"done": true,
"total_duration": 15000000000,
"load_duration": 2000000000
}
```

View File

@@ -1468,6 +1468,7 @@ type CompletionRequest struct {
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
@@ -1518,10 +1519,14 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
// Image generation fields
Image []byte `json:"image,omitempty"` // Generated image
Step int `json:"step,omitempty"` // Current generation step
Total int `json:"total,omitempty"` // Total generation steps
// Image contains base64-encoded image data for image generation
Image string `json:"image,omitempty"`
// Step is the current step in image generation
Step int `json:"step,omitempty"`
// TotalSteps is the total number of steps for image generation
TotalSteps int `json:"total_steps,omitempty"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

View File

@@ -50,6 +50,11 @@ type EmbedWriter struct {
encodingFormat string
}
// ImageWriter adapts the streaming /x/generate response into a single
// OpenAI-compatible image generation response: progress chunks are
// swallowed and only the final chunk carrying images is written out.
type ImageWriter struct {
	BaseWriter
	// done records that the final OpenAI-format response has been written.
	done bool
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
@@ -274,6 +279,36 @@ func (w *EmbedWriter) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
// writeResponse consumes one streamed api.GenerateResponse chunk. Progress
// updates are swallowed — the OpenAI images API is not streamed — and only
// the final chunk that carries images is converted and written to the client.
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
	var resp api.GenerateResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return 0, err
	}

	// Nothing to emit until generation finishes with at least one image.
	if !resp.Done || len(resp.Images) == 0 {
		return len(data), nil
	}

	w.done = true
	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(resp)); err != nil {
		return 0, err
	}
	return len(data), nil
}
// Write routes upstream output either to the error translator (non-200
// status) or to the OpenAI response converter.
func (w *ImageWriter) Write(data []byte) (int, error) {
	if w.ResponseWriter.Status() != http.StatusOK {
		return w.writeError(data)
	}
	return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
@@ -393,6 +428,43 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
}
}
// ImageGenerationsMiddleware translates an OpenAI-style image generation
// request into an Ollama api.GenerateRequest, rewrites the request body,
// and installs an ImageWriter so the downstream handler's streamed output
// is returned in OpenAI format.
func ImageGenerationsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req openai.ImageGenerationRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
			return
		}

		// Validate required fields before touching the body.
		switch {
		case req.Prompt == "":
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
			return
		case req.Model == "":
			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
			return
		}

		var body bytes.Buffer
		if err := json.NewEncoder(&body).Encode(openai.FromImageGenerationRequest(req)); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&body)
		c.Writer = &ImageWriter{
			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
		}
		c.Next()
	}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ChatCompletionRequest

View File

@@ -961,3 +961,143 @@ func TestRetrieveMiddleware(t *testing.T) {
}
}
}
// TestImageGenerationsMiddleware exercises the OpenAI-compatible image
// generation middleware: valid requests must be rewritten into
// api.GenerateRequest form for the downstream handler, and invalid ones
// must be rejected with OpenAI-style errors.
func TestImageGenerationsMiddleware(t *testing.T) {
	stream := false

	var capturedRequest *api.GenerateRequest

	cases := []struct {
		name string
		body string
		req  api.GenerateRequest
		err  openai.ErrorResponse
	}{
		{
			name: "image generation handler",
			body: `{
				"model": "flux",
				"prompt": "a cat"
			}`,
			req: api.GenerateRequest{
				Model:  "flux",
				Prompt: "a cat",
				Stream: &stream,
			},
		},
		{
			// Size is currently dropped by the conversion; the forwarded
			// request carries no width/height.
			name: "image generation with size",
			body: `{
				"model": "flux",
				"prompt": "a dog",
				"size": "512x512"
			}`,
			req: api.GenerateRequest{
				Model:  "flux",
				Prompt: "a dog",
				Stream: &stream,
			},
		},
		{
			name: "missing prompt error",
			body: `{
				"model": "flux"
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "prompt is required",
					Type:    "invalid_request_error",
				},
			},
		},
		{
			name: "missing model error",
			body: `{
				"prompt": "a cat"
			}`,
			err: openai.ErrorResponse{
				Error: openai.Error{
					Message: "model is required",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			httpReq, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body))
			httpReq.Header.Set("Content-Type", "application/json")

			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, httpReq)

			var errResp openai.ErrorResponse
			if rec.Code != http.StatusOK {
				if err := json.Unmarshal(rec.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tt.req, *capturedRequest) {
				t.Fatalf("requests did not match\nExpected: %+v\nActual: %+v", tt.req, *capturedRequest)
			}

			if !reflect.DeepEqual(tt.err, errResp) {
				t.Fatalf("errors did not match\nExpected: %+v\nActual: %+v", tt.err, errResp)
			}

			capturedRequest = nil
		})
	}
}
// TestImageWriterIntegration verifies that ImageGenerationsMiddleware and
// ImageWriter together convert a native generate response into the OpenAI
// image generation format.
func TestImageWriterIntegration(t *testing.T) {
	gin.SetMode(gin.TestMode)

	t.Run("transforms generate response to openai format", func(t *testing.T) {
		router := gin.New()
		router.Use(ImageGenerationsMiddleware())
		router.POST("/api/generate", func(c *gin.Context) {
			// Pretend to be the native handler emitting a finished image.
			c.JSON(http.StatusOK, api.GenerateResponse{
				Done:      true,
				CreatedAt: time.Now(),
				Images:    []string{"base64encodedimage"},
			})
		})

		body := strings.NewReader(`{"model":"flux","prompt":"a cat"}`)
		httpReq, _ := http.NewRequest(http.MethodPost, "/api/generate", body)
		httpReq.Header.Set("Content-Type", "application/json")

		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, httpReq)

		if rec.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d: %s", rec.Code, rec.Body.String())
		}

		var got openai.ImageGenerationResponse
		if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
			t.Fatalf("failed to unmarshal response: %v", err)
		}
		if len(got.Data) != 1 {
			t.Fatalf("expected 1 image, got %d", len(got.Data))
		}
		if got.Data[0].B64JSON != "base64encodedimage" {
			t.Fatalf("expected image data 'base64encodedimage', got '%s'", got.Data[0].B64JSON)
		}
	})
}

View File

@@ -737,3 +737,46 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	// N is the requested number of images. NOTE(review): FromImageGenerationRequest
	// does not forward it, so only a single image is produced — confirm whether
	// N should be honored or rejected when > 1.
	N int `json:"n,omitempty"`
	// Size is the requested dimensions as "WIDTHxHEIGHT" (e.g. "512x512").
	// NOTE(review): FromImageGenerationRequest currently drops this field,
	// so the model's default dimensions are always used.
	Size string `json:"size,omitempty"`
	// ResponseFormat selects "url" or "b64_json". NOTE(review): responses are
	// always returned as b64_json (see ToImageGenerationResponse); "url" is
	// not supported.
	ResponseFormat string `json:"response_format,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
	// Created is the Unix timestamp (seconds) of when the response was created.
	Created int64 `json:"created"`
	// Data holds one entry per generated image.
	Data []ImageURLOrData `json:"data"`
}
// ImageURLOrData contains either a URL or base64-encoded image data.
// The server currently populates only B64JSON (see ToImageGenerationResponse).
type ImageURLOrData struct {
	// URL is a link to a hosted copy of the generated image.
	URL string `json:"url,omitempty"`
	// B64JSON is the base64-encoded image bytes.
	B64JSON string `json:"b64_json,omitempty"`
}
// FromImageGenerationRequest converts an OpenAI image generation request to
// an Ollama GenerateRequest.
//
// Only model and prompt are carried over; size, n, and response_format are
// currently ignored. NOTE(review): parsing Size ("WIDTHxHEIGHT") into the
// request's Width/Height fields would restore dimension control — confirm
// intent before wiring it up, as existing tests assert the field is dropped.
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
	// The OpenAI images API is non-streaming, so force streaming off.
	nonStreaming := false
	return api.GenerateRequest{
		Model:  r.Model,
		Prompt: r.Prompt,
		Stream: &nonStreaming,
	}
}
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI
// ImageGenerationResponse. Every image is returned as base64 data (b64_json);
// the Data slice is always non-nil so it encodes as [] rather than null.
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
	images := make([]ImageURLOrData, len(resp.Images))
	for i, b64 := range resp.Images {
		images[i] = ImageURLOrData{B64JSON: b64}
	}
	return ImageGenerationResponse{
		Created: resp.CreatedAt.Unix(),
		Data:    images,
	}
}

View File

@@ -1587,6 +1587,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Experimental image generation
r.POST("/x/generate", s.XGenerateHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
@@ -1594,8 +1597,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Experimental OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", s.handleImageGeneration)
// OpenAI-compatible image generation endpoint (uses experimental /x/generate)
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.XGenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -1908,6 +1911,105 @@ func (s *Server) PsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// XGenerateHandler handles the experimental /x/generate endpoint for image
// generation. It validates the request and the model's image-generation
// capability, schedules a runner, and streams progress updates followed by
// a final response carrying base64-encoded images. An empty prompt is a
// load-only request: the model is scheduled and an immediate done response
// is returned.
func (s *Server) XGenerateHandler(c *gin.Context) {
	// Marks the start of the whole request, used for TotalDuration.
	checkpointStart := time.Now()
	var req api.GenerateRequest
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	name := model.ParseName(req.Model)
	if !name.IsValid() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}
	// Resolve to a locally installed model name (may canonicalize casing/tags).
	name, err := getExistingName(name)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
	}
	m, err := GetModel(name.String())
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Check that this is an image generation model
	if !slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support image generation", req.Model)})
		return
	}
	// Schedule the runner
	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{model.CapabilityImageGeneration}, req.Options, req.KeepAlive)
	if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}
	// Marks the point at which the model is loaded, used for LoadDuration.
	checkpointLoaded := time.Now()
	// Handle load-only request
	if req.Prompt == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Done:       true,
			DoneReason: "load",
		})
		return
	}
	// Completion runs in a goroutine; each callback result is forwarded over
	// ch and streamed to the client by streamResponse below. Closing ch (via
	// defer) is what terminates the stream.
	ch := make(chan any)
	go func() {
		defer close(ch)
		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
			Prompt: req.Prompt,
			Width:  req.Width,
			Height: req.Height,
			Steps:  req.Steps,
		}, func(cr llm.CompletionResponse) {
			res := api.GenerateResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Done:      cr.Done,
			}
			// Image generation progress
			if cr.TotalSteps > 0 {
				res.Status = "generating image"
				res.Completed = int64(cr.Step)
				res.Total = int64(cr.TotalSteps)
			}
			// Final image
			if cr.Image != "" {
				res.Images = []string{cr.Image}
			}
			if cr.Done {
				res.DoneReason = cr.DoneReason.String()
				res.TotalDuration = time.Since(checkpointStart)
				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
			}
			ch <- res
		}); err != nil {
			// Errors from the runner are streamed to the client as a JSON
			// error object rather than an HTTP error status.
			ch <- gin.H{"error": err.Error()}
		}
	}()
	streamResponse(c, ch)
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
@@ -1917,62 +2019,6 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}
func (s *Server) handleImageGeneration(c *gin.Context) {
var req struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
m, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Parse size (e.g., "1024x768") into width and height
width, height := int32(1024), int32(1024)
if req.Size != "" {
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
return
}
}
var image []byte
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: width,
Height: height,
}, func(resp llm.CompletionResponse) {
if len(resp.Image) > 0 {
image = resp.Image
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"created": time.Now().Unix(),
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
})
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()

View File

@@ -91,9 +91,7 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
}
// generateImageWithOptions generates an image with the given options.
// Note: opts are currently unused as the native API doesn't support size parameters.
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -102,7 +100,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
@@ -115,27 +115,20 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
err = client.XGenerate(cmd.Context(), req, func(resp api.GenerateResponse) error {
// Handle progress updates using structured fields
if resp.Total > 0 && resp.Completed > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
stepBar.Set(int(resp.Completed))
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil
@@ -235,12 +228,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
@@ -254,27 +244,20 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
err = client.XGenerate(cmd.Context(), req, func(resp api.GenerateResponse) error {
// Handle progress updates using structured fields
if resp.Total > 0 && resp.Completed > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
stepBar.Set(int(resp.Completed))
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -233,11 +232,13 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
}
@@ -280,15 +281,11 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
Total: raw.Total,
}
if raw.Image != "" {
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
cresp.Image = data
}
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
}
fn(cresp)