mirror of
https://github.com/ollama/ollama.git
synced 2026-01-15 10:58:36 -05:00
Compare commits
7 Commits
main
...
imagegen-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1c7d8718f | ||
|
|
9d07e26b62 | ||
|
|
8485b6546e | ||
|
|
d9ae425d54 | ||
|
|
9e1d79ac67 | ||
|
|
7273d9925e | ||
|
|
4896240528 |
@@ -127,6 +127,10 @@ type GenerateRequest struct {
|
||||
// each with an associated log probability. Only applies when Logprobs is true.
|
||||
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// Size specifies the image dimensions for image generation models.
|
||||
// Format: "WxH" (e.g., "1024x1024"). OpenAI-compatible.
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
|
||||
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1464,6 +1464,10 @@ type CompletionRequest struct {
|
||||
|
||||
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
|
||||
TopLogprobs int
|
||||
|
||||
// Size specifies image dimensions for image generation models.
|
||||
// Format: "WxH" (e.g., "1024x1024"). OpenAI-compatible.
|
||||
Size string
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -37,11 +36,10 @@ type Terminal struct {
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
pastedLines []string
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
}
|
||||
|
||||
func New(prompt Prompt) (*Instance, error) {
|
||||
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharEsc:
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
i.pastedLines = nil
|
||||
i.Prompt.UseAlt = false
|
||||
return "", ErrInterrupt
|
||||
case CharPrev:
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharForward:
|
||||
buf.MoveRight()
|
||||
case CharBackspace, CharCtrlH:
|
||||
if buf.IsEmpty() && len(i.pastedLines) > 0 {
|
||||
lastIdx := len(i.pastedLines) - 1
|
||||
prevLine := i.pastedLines[lastIdx]
|
||||
i.pastedLines = i.pastedLines[:lastIdx]
|
||||
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
|
||||
if len(i.pastedLines) == 0 {
|
||||
fmt.Print(i.Prompt.Prompt)
|
||||
i.Prompt.UseAlt = false
|
||||
} else {
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
}
|
||||
for _, r := range prevLine {
|
||||
buf.Add(r)
|
||||
}
|
||||
} else {
|
||||
buf.Remove()
|
||||
}
|
||||
buf.Remove()
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for range 8 {
|
||||
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharCtrlZ:
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
case CharEnter:
|
||||
case CharEnter, CharCtrlJ:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
output = strings.Join(i.pastedLines, "\n") + "\n" + output
|
||||
i.pastedLines = nil
|
||||
}
|
||||
if output != "" {
|
||||
i.History.Add(output)
|
||||
}
|
||||
buf.MoveToEnd()
|
||||
fmt.Println()
|
||||
i.Prompt.UseAlt = false
|
||||
|
||||
return output, nil
|
||||
default:
|
||||
|
||||
@@ -216,7 +216,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
// Check if this is a known image generation model
|
||||
if imagegen.ResolveModelName(req.Model) != "" {
|
||||
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
|
||||
imagegenapi.HandleGenerateRequest(c, s, &req, streamResponse)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -574,7 +574,6 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
sessionDuration: sessionDuration,
|
||||
refCount: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
|
||||
67
x/cmd/run.go
67
x/cmd/run.go
@@ -25,6 +25,14 @@ import (
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// MultilineState tracks the state of multiline input
|
||||
type MultilineState int
|
||||
|
||||
const (
|
||||
MultilineNone MultilineState = iota
|
||||
MultilineSystem
|
||||
)
|
||||
|
||||
// Tool output capping constants
|
||||
const (
|
||||
// localModelTokenLimit is the token limit for local models (smaller context).
|
||||
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var sb strings.Builder
|
||||
var format string
|
||||
var system string
|
||||
var multiline MultilineState = MultilineNone
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
multiline = MultilineNone
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case multiline != MultilineNone:
|
||||
// check if there's a multiline terminating string
|
||||
before, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
continue
|
||||
}
|
||||
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
multiline = MultilineNone
|
||||
scanner.Prompt.UseAlt = false
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
options[args[2]] = fp[args[2]]
|
||||
case "system":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /set system <message>")
|
||||
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
|
||||
continue
|
||||
}
|
||||
|
||||
system = strings.Join(args[2:], " ")
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
multiline = MultilineSystem
|
||||
|
||||
line := strings.Join(args[2:], " ")
|
||||
line, ok := strings.CutPrefix(line, `"""`)
|
||||
if !ok {
|
||||
multiline = MultilineNone
|
||||
} else {
|
||||
// only cut suffix if the line is multiline
|
||||
line, ok = strings.CutSuffix(line, `"""`)
|
||||
if ok {
|
||||
multiline = MultilineNone
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
if multiline != MultilineNone {
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||
// Check if the slice is not empty and the last message is from 'system'
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
// Replace the last message
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 {
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -50,7 +50,7 @@ func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
req.N = 1
|
||||
}
|
||||
if req.Size == "" {
|
||||
req.Size = "1024x1024"
|
||||
req.Size = fmt.Sprintf("%dx%d", imagegen.DefaultWidth, imagegen.DefaultHeight)
|
||||
}
|
||||
if req.ResponseFormat == "" {
|
||||
req.ResponseFormat = "b64_json"
|
||||
@@ -62,16 +62,8 @@ func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size
|
||||
width, height := parseSize(req.Size)
|
||||
|
||||
// Build options - we repurpose NumCtx/NumGPU for width/height
|
||||
opts := api.Options{}
|
||||
opts.NumCtx = int(width)
|
||||
opts.NumGPU = int(height)
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, api.Options{}, nil)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
@@ -81,10 +73,10 @@ func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
// Build completion request with size (OpenAI format)
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Options: &opts,
|
||||
Prompt: req.Prompt,
|
||||
Size: req.Size,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
@@ -134,22 +126,6 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
|
||||
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractBase64(content string) string {
|
||||
if strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
return content[13:]
|
||||
@@ -185,20 +161,18 @@ func buildResponse(imageBase64, format string) ImageGenerationResponse {
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, req *api.GenerateRequest, streamFn func(c *gin.Context, ch chan any)) {
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, api.Options{}, req.KeepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
// Build completion request with size (OpenAI format)
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
Prompt: req.Prompt,
|
||||
Size: req.Size,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
@@ -207,15 +181,14 @@ func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName,
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't block - channel is already being consumed
|
||||
_ = err
|
||||
slog.Error("image generation failed", "model", req.Model, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -37,9 +37,9 @@ type ImageGenOptions struct {
|
||||
// DefaultOptions returns the default image generation options.
|
||||
func DefaultOptions() ImageGenOptions {
|
||||
return ImageGenOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Width: DefaultWidth,
|
||||
Height: DefaultHeight,
|
||||
Steps: 0, // 0 means model default
|
||||
Seed: 0, // 0 means random
|
||||
}
|
||||
}
|
||||
@@ -107,9 +107,9 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
// RegisterFlags adds image generation flags to the given command.
|
||||
// Flags are hidden since they only apply to image generation models.
|
||||
func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("width", 1024, "Image width")
|
||||
cmd.Flags().Int("height", 1024, "Image height")
|
||||
cmd.Flags().Int("steps", 9, "Denoising steps")
|
||||
cmd.Flags().Int("width", DefaultWidth, "Image width")
|
||||
cmd.Flags().Int("height", DefaultHeight, "Image height")
|
||||
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
cmd.Flags().MarkHidden("width")
|
||||
@@ -158,17 +158,10 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
// Build request with image gen options encoded in Options fields
|
||||
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
Size: fmt.Sprintf("%dx%d", opts.Width, opts.Height),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
@@ -46,9 +47,9 @@ func main() {
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 9, "Denoising steps")
|
||||
width := flag.Int("width", imagegen.DefaultWidth, "Image width")
|
||||
height := flag.Int("height", imagegen.DefaultHeight, "Image height")
|
||||
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
@@ -149,10 +150,10 @@ func main() {
|
||||
// unless explicitly overridden from defaults
|
||||
editWidth := int32(0)
|
||||
editHeight := int32(0)
|
||||
if *width != 1024 {
|
||||
if *width != imagegen.DefaultWidth {
|
||||
editWidth = int32(*width)
|
||||
}
|
||||
if *height != 1024 {
|
||||
if *height != imagegen.DefaultHeight {
|
||||
editHeight = int32(*height)
|
||||
}
|
||||
|
||||
|
||||
7
x/imagegen/defaults.go
Normal file
7
x/imagegen/defaults.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package imagegen
|
||||
|
||||
// Default image generation parameters.
|
||||
const (
|
||||
DefaultWidth = 1024
|
||||
DefaultHeight = 1024
|
||||
)
|
||||
@@ -95,8 +95,3 @@ func EstimateVRAM(modelName string) uint64 {
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// HasTensorLayers checks if the given model has tensor layers.
|
||||
func HasTensorLayers(modelName string) bool {
|
||||
return ResolveModelName(modelName) != ""
|
||||
}
|
||||
|
||||
@@ -94,13 +94,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasTensorLayers(t *testing.T) {
|
||||
// Non-existent model should return false
|
||||
if HasTensorLayers("nonexistent-model") {
|
||||
t.Error("HasTensorLayers() should return false for non-existent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
@@ -166,10 +167,10 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
cfg.Width = imagegen.DefaultWidth
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
cfg.Height = imagegen.DefaultHeight
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 30
|
||||
|
||||
@@ -188,13 +188,13 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
cfg.Width = imagegen.DefaultWidth
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
cfg.Height = imagegen.DefaultHeight
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 9 // Turbo default
|
||||
cfg.Steps = 9 // Z-Image turbo default
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
|
||||
@@ -136,15 +136,12 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Apply defaults
|
||||
// Apply defaults (steps left to model)
|
||||
if req.Width <= 0 {
|
||||
req.Width = 1024
|
||||
req.Width = imagegen.DefaultWidth
|
||||
}
|
||||
if req.Height <= 0 {
|
||||
req.Height = 1024
|
||||
}
|
||||
if req.Steps <= 0 {
|
||||
req.Steps = 9
|
||||
req.Height = imagegen.DefaultHeight
|
||||
}
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
|
||||
@@ -33,10 +33,12 @@ type Server struct {
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string // Last stderr line for error reporting
|
||||
lastErrLock sync.Mutex
|
||||
stderrLines []string // Recent stderr lines for error reporting (max 10)
|
||||
stderrLock sync.Mutex
|
||||
}
|
||||
|
||||
const maxStderrLines = 10
|
||||
|
||||
// completionRequest is sent to the subprocess
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -139,10 +141,13 @@ func NewServer(modelName string) (*Server, error) {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
slog.Warn("image-runner", "msg", line)
|
||||
// Capture last error line for better error reporting
|
||||
s.lastErrLock.Lock()
|
||||
s.lastErr = line
|
||||
s.lastErrLock.Unlock()
|
||||
// Capture recent stderr lines for error reporting
|
||||
s.stderrLock.Lock()
|
||||
s.stderrLines = append(s.stderrLines, line)
|
||||
if len(s.stderrLines) > maxStderrLines {
|
||||
s.stderrLines = s.stderrLines[1:]
|
||||
}
|
||||
s.stderrLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -171,7 +176,9 @@ func (s *Server) ModelPath() string {
|
||||
return s.modelName
|
||||
}
|
||||
|
||||
// Load is called by the scheduler after the server is created.
|
||||
// Load is a no-op for image generation models.
|
||||
// Unlike LLM models, imagegen models are loaded by the subprocess at startup
|
||||
// rather than through this interface method.
|
||||
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -204,20 +211,16 @@ func (s *Server) waitUntilRunning() error {
|
||||
for {
|
||||
select {
|
||||
case err := <-s.done:
|
||||
// Include last stderr line for better error context
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
|
||||
// Include recent stderr lines for better error context
|
||||
stderrContext := s.getStderrContext()
|
||||
if stderrContext != "" {
|
||||
return fmt.Errorf("image runner failed: %s (exit: %v)", stderrContext, err)
|
||||
}
|
||||
return fmt.Errorf("image runner exited unexpectedly: %w", err)
|
||||
case <-timeout:
|
||||
s.lastErrLock.Lock()
|
||||
lastErr := s.lastErr
|
||||
s.lastErrLock.Unlock()
|
||||
if lastErr != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
|
||||
stderrContext := s.getStderrContext()
|
||||
if stderrContext != "" {
|
||||
return fmt.Errorf("timeout waiting for image runner: %s", stderrContext)
|
||||
}
|
||||
return errors.New("timeout waiting for image runner to start")
|
||||
case <-ticker.C:
|
||||
@@ -229,34 +232,38 @@ func (s *Server) waitUntilRunning() error {
|
||||
}
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
|
||||
// getStderrContext returns recent stderr lines joined as a single string.
|
||||
func (s *Server) getStderrContext() string {
|
||||
s.stderrLock.Lock()
|
||||
defer s.stderrLock.Unlock()
|
||||
if len(s.stderrLines) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(s.stderrLines, "; ")
|
||||
}
|
||||
|
||||
// WaitUntilRunning is a no-op for image generation models.
|
||||
// NewServer already blocks until the subprocess is ready, so this method
|
||||
// returns immediately. Required by the llm.LlamaServer interface.
|
||||
func (s *Server) WaitUntilRunning(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Completion generates an image from the prompt via the subprocess.
|
||||
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
// Build request
|
||||
// Build request with defaults (steps left to model)
|
||||
creq := completionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 9,
|
||||
Width: DefaultWidth,
|
||||
Height: DefaultHeight,
|
||||
Seed: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
if req.Options != nil {
|
||||
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
|
||||
creq.Width = int32(req.Options.NumCtx)
|
||||
}
|
||||
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
|
||||
creq.Height = int32(req.Options.NumGPU)
|
||||
}
|
||||
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
|
||||
creq.Steps = req.Options.NumPredict
|
||||
}
|
||||
if req.Options.Seed > 0 {
|
||||
creq.Seed = int64(req.Options.Seed)
|
||||
// Parse size string (OpenAI format: "WxH")
|
||||
if req.Size != "" {
|
||||
if w, h := parseSize(req.Size); w > 0 && h > 0 {
|
||||
creq.Width = w
|
||||
creq.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,17 +353,20 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// Embedding is not supported for image generation models.
|
||||
// Embedding returns an error as image generation models don't produce embeddings.
|
||||
// Required by the llm.LlamaServer interface.
|
||||
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return nil, 0, errors.New("embedding not supported for image generation models")
|
||||
}
|
||||
|
||||
// Tokenize is not supported for image generation models.
|
||||
// Tokenize returns an error as image generation uses internal tokenization.
|
||||
// Required by the llm.LlamaServer interface.
|
||||
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return nil, errors.New("tokenize not supported for image generation models")
|
||||
}
|
||||
|
||||
// Detokenize is not supported for image generation models.
|
||||
// Detokenize returns an error as image generation uses internal tokenization.
|
||||
// Required by the llm.LlamaServer interface.
|
||||
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", errors.New("detokenize not supported for image generation models")
|
||||
}
|
||||
@@ -376,7 +386,8 @@ func (s *Server) GetPort() int {
|
||||
return s.port
|
||||
}
|
||||
|
||||
// GetDeviceInfos returns nil since we don't track GPU info.
|
||||
// GetDeviceInfos returns nil as GPU tracking is handled by the subprocess.
|
||||
// Required by the llm.LlamaServer interface.
|
||||
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
@@ -393,3 +404,14 @@ func (s *Server) HasExited() bool {
|
||||
|
||||
// Ensure Server implements llm.LlamaServer
|
||||
var _ llm.LlamaServer = (*Server)(nil)
|
||||
|
||||
// parseSize parses an OpenAI-style size string "WxH" into width and height.
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 0, 0
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user