Compare commits

...

1 Commits

Author SHA1 Message Date
jmorganca
8e56dab90b Add experimental image generation fields to /api/generate
Request fields (experimental):
- width: image width (max 4096)
- height: image height (max 4096)
- steps: denoising steps
- seed: random seed

Response fields (experimental):
- images: base64-encoded generated images
- completed: current step progress
- total: total steps

Other changes:
- Fix lifecycle bug where image models wouldn't unload (refCount issue)
- Fix "headers already written" error on Ctrl+C during streaming
- Add gin middleware for OpenAI /v1/images/generations compatibility
- Update CLI to use /api/generate with progress bar
- Add preload support in interactive mode
2026-01-17 14:08:06 -08:00
16 changed files with 556 additions and 125 deletions

View File

@@ -127,6 +127,25 @@ type GenerateRequest struct {
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Width is the width of the generated image in pixels.
// Only used for image generation models.
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image in pixels.
// Only used for image generation models.
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps for image generation.
// Only used for image generation models.
Steps int32 `json:"steps,omitempty"`
// Seed is the random seed for reproducible image generation.
// If 0 or not specified, a random seed will be used.
// Only used for image generation models.
Seed int64 `json:"seed,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -860,6 +879,20 @@ type GenerateResponse struct {
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Images contains base64-encoded generated images.
// Only present for image generation models.
Images []string `json:"images,omitempty"`
// Completed is the number of completed steps in image generation.
// Only present for image generation models during streaming.
Completed int64 `json:"completed,omitempty"`
// Total is the total number of steps for image generation.
// Only present for image generation models during streaming.
Total int64 `json:"total,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -600,7 +600,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if slices.Contains(info.Capabilities, model.CapabilityImage) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
@@ -1985,6 +1985,7 @@ func NewCLI() *cobra.Command {
} {
switch cmd {
case runCmd:
imagegen.AppendFlagsDocs(cmd)
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{

View File

@@ -1555,7 +1555,7 @@ func TestShowInfoImageGen(t *testing.T) {
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImageGeneration},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
}, false, &b)
if err != nil {

View File

@@ -16,6 +16,7 @@
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Version](#version)
- [Experimental: Image Generation](#image-generation-experimental)
## Conventions
@@ -58,6 +59,16 @@ Advanced parameters (optional):
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
Experimental image generation parameters (for image generation models only):
> [!WARNING]
> These parameters are experimental and may change in future versions.
- `width`: width of the generated image in pixels (default: model-specific, typically 1024)
- `height`: height of the generated image in pixels (default: model-specific, typically 1024)
- `steps`: number of diffusion steps (default: model-specific)
- `seed`: random seed for reproducible image generation (default: random)
#### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
@@ -1867,3 +1878,55 @@ curl http://localhost:11434/api/version
"version": "0.5.1"
}
```
## Experimental Features
### Image Generation (Experimental)
> [!WARNING]
> Image generation is experimental and may change in future versions.
Image generation is now supported through the standard `/api/generate` endpoint when using image generation models (such as Flux). The API automatically detects when an image generation model is being used.
See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`, `seed`) are documented there.
#### Example
##### Request
```shell
curl http://localhost:11434/api/generate -d '{
"model": "flux",
"prompt": "a sunset over mountains",
"width": 1024,
"height": 768
}'
```
##### Response (streaming)
Progress updates during generation:
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:00.000000Z",
"completed": 5,
"total": 20,
"done": false
}
```
##### Final Response
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:15.000000Z",
"images": ["iVBORw0KGgoAAAANSUhEUg..."],
"done": true,
"done_reason": "stop",
"total_duration": 15000000000,
"load_duration": 2000000000
}
```

View File

@@ -1468,6 +1468,7 @@ type CompletionRequest struct {
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
@@ -1518,10 +1519,14 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
// Image generation fields
Image []byte `json:"image,omitempty"` // Generated image
Step int `json:"step,omitempty"` // Current generation step
Total int `json:"total,omitempty"` // Total generation steps
// Image contains base64-encoded image data for image generation
Image string `json:"image,omitempty"`
// Step is the current step in image generation
Step int `json:"step,omitempty"`
// TotalSteps is the total number of steps for image generation
TotalSteps int `json:"total_steps,omitempty"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

View File

@@ -546,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
c.Next()
}
}
type ImageWriter struct {
BaseWriter
}
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
if err := json.Unmarshal(data, &generateResponse); err != nil {
return 0, err
}
// Only write response when done with images
if generateResponse.Done && len(generateResponse.Images) > 0 {
w.ResponseWriter.Header().Set("Content-Type", "application/json")
return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
}
return len(data), nil
}
func (w *ImageWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ImageGenerationsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageGenerationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}

View File

@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
}
}
}
func TestImageGenerationsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
name: "image generation basic",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
},
},
{
name: "image generation with size",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
Width: 512,
Height: 768,
},
},
{
name: "image generation missing prompt",
body: `{
"model": "test-model"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image generation missing model",
body: `{
"prompt": "a beautiful sunset"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}
func TestImageWriterResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
// Test that ImageWriter transforms GenerateResponse to OpenAI format
endpoint := func(c *gin.Context) {
resp := api.GenerateResponse{
Model: "test-model",
CreatedAt: time.Unix(1234567890, 0).UTC(),
Done: true,
Images: []string{"dGVzdC1pbWFnZS1kYXRh"}, // base64 of "test-image-data"
}
data, _ := json.Marshal(resp)
c.Writer.Write(append(data, '\n'))
}
router := gin.New()
router.Use(ImageGenerationsMiddleware())
router.Handle(http.MethodPost, "/api/generate", endpoint)
body := `{"model": "test-model", "prompt": "test"}`
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
var imageResp openai.ImageGenerationResponse
if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if imageResp.Created != 1234567890 {
t.Errorf("expected created 1234567890, got %d", imageResp.Created)
}
if len(imageResp.Data) != 1 {
t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
}
if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}

View File

@@ -737,3 +737,57 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Seed *int64 `json:"seed,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageURLOrData `json:"data"`
}
// ImageURLOrData contains either a URL or base64-encoded image data.
type ImageURLOrData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
}
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
req := api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
}
// Parse size if provided (e.g., "1024x768")
if r.Size != "" {
var w, h int32
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
req.Width = w
req.Height = h
}
}
if r.Seed != nil {
req.Seed = *r.Seed
}
return req
}
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
data := make([]ImageURLOrData, 0)
for _, img := range resp.Images {
data = append(data, ImageURLOrData{B64JSON: img})
}
return ImageGenerationResponse{
Created: resp.CreatedAt.Unix(),
Data: data,
}
}

View File

@@ -41,6 +41,7 @@ var (
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errCapabilityImage = errors.New("image generation")
errInsecureProtocol = errors.New("insecure protocol http")
)
@@ -76,7 +77,7 @@ func (m *Model) Capabilities() []model.Capability {
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
return []model.Capability{model.CapabilityImage}
}
// Check for completion capability
@@ -159,6 +160,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
model.CapabilityImage: errCapabilityImage,
}
for _, cap := range want {

View File

@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
expectedCaps: []model.Capability{model.CapabilityImage},
},
{
name: "model with completion capability",
@@ -242,6 +242,24 @@ func TestModelCheckCapabilities(t *testing.T) {
checkCaps: []model.Capability{"unknown"},
expectedErrMsg: "unknown capability",
},
{
name: "model missing image generation capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityImage},
expectedErrMsg: "does not support image generation",
},
{
name: "model with image generation capability",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
checkCaps: []model.Capability{model.CapabilityImage},
},
}
for _, tt := range tests {

View File

@@ -220,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
@@ -1096,7 +1102,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
@@ -1202,7 +1208,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
@@ -1594,8 +1600,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// Experimental OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", s.handleImageGeneration)
// OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -1917,62 +1923,6 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}
func (s *Server) handleImageGeneration(c *gin.Context) {
var req struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
m, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Parse size (e.g., "1024x768") into width and height
width, height := int32(1024), int32(1024)
if req.Size != "" {
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
return
}
}
var image []byte
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: width,
Height: height,
}, func(resp llm.CompletionResponse) {
if len(resp.Image) > 0 {
image = resp.Image
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"created": time.Now().Unix(),
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
})
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
@@ -2522,3 +2472,78 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}
// handleImageGenerate handles image generation requests within GenerateHandler.
// This is called when the model has the ImageGeneration capability.
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
// Validate image dimensions
const maxDimension int32 = 4096
if req.Width > maxDimension || req.Height > maxDimension {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
return
}
// Schedule the runner for image generation
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
// Handle load-only request (empty prompt)
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
// Set headers for streaming response
c.Header("Content-Type", "application/x-ndjson")
var streamStarted bool
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: cr.Done,
}
if cr.TotalSteps > 0 {
res.Completed = int64(cr.Step)
res.Total = int64(cr.TotalSteps)
}
if cr.Image != "" {
res.Images = []string{cr.Image}
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.Metrics.TotalDuration = time.Since(checkpointStart)
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
}); err != nil {
// Only send JSON error if streaming hasn't started yet
// (once streaming starts, headers are committed and we can't change status code)
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
}

View File

@@ -571,10 +571,10 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
}
s.loadedMu.Lock()

View File

@@ -9,7 +9,7 @@ const (
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
CapabilityImage = Capability("image")
)
func (c Capability) String() string {

View File

@@ -51,6 +51,7 @@ func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
// Hide from main flags section - shown in separate section via AppendFlagsDocs
cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps")
@@ -58,6 +59,19 @@ func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().MarkHidden("negative")
}
// AppendFlagsDocs appends image generation flags documentation to the command's usage template.
func AppendFlagsDocs(cmd *cobra.Command) {
usage := `
Image Generation Flags (experimental):
--width int Image width
--height int Image height
--steps int Denoising steps
--seed int Random seed
--negative str Negative prompt
`
cmd.SetUsageTemplate(cmd.UsageTemplate() + usage)
}
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
@@ -91,9 +105,7 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
}
// generateImageWithOptions generates an image with the given options.
// Note: opts are currently unused as the native API doesn't support size parameters.
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -102,7 +114,10 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
Seed: int64(opts.Seed),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
@@ -116,32 +131,25 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
// Handle progress updates using structured fields
if resp.Total > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
stepBar.Set(int(resp.Completed))
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil
})
p.Stop()
p.StopAndClear()
if err != nil {
return err
}
@@ -179,6 +187,23 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return err
}
// Preload the model with the specified keepalive
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
preloadReq := &api.GenerateRequest{
Model: modelName,
KeepAlive: keepAlive,
}
if err := client.Generate(cmd.Context(), preloadReq, func(resp api.GenerateResponse) error {
return nil
}); err != nil {
p.StopAndClear()
return fmt.Errorf("failed to load model: %w", err)
}
p.StopAndClear()
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)",
@@ -235,12 +260,10 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
Seed: int64(opts.Seed),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
@@ -255,32 +278,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
// Handle progress updates using structured fields
if resp.Total > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
stepBar.Set(int(resp.Completed))
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil
})
p.Stop()
p.StopAndClear()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue

View File

@@ -36,6 +36,8 @@ type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
// Server holds the model and handles requests
@@ -167,8 +169,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
Done: false,
Step: step,
Total: total,
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -232,11 +231,13 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
}
@@ -279,15 +280,11 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
Total: raw.Total,
}
if raw.Image != "" {
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
cresp.Image = data
}
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
}
fn(cresp)