Compare commits


1 commit

Author: Parth Sareen
SHA1: 75d7b5f926
Message: cmd: enable multi-line input and shift enter (#13694)
Date: 2026-01-14 17:52:46 -08:00
17 changed files with 170 additions and 181 deletions

View File

@@ -127,10 +127,6 @@ type GenerateRequest struct {
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
// Size specifies the image dimensions for image generation models.
// Format: "WxH" (e.g., "1024x1024"). OpenAI-compatible.
Size string `json:"size,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err

View File

@@ -1464,10 +1464,6 @@ type CompletionRequest struct {
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
// Size specifies image dimensions for image generation models.
// Format: "WxH" (e.g., "1024x1024"). OpenAI-compatible.
Size string
}
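Note that the completion request no longer carries a dedicated size field; later hunks in this commit route width and height through the existing Options fields (NumCtx and NumGPU) instead.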
// DoneReason represents the reason why a completion response is done

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"strings"
)
type Prompt struct {
@@ -36,10 +37,11 @@ type Terminal struct {
}
type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
pastedLines []string
}
func New(prompt Prompt) (*Instance, error) {
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
buf.Remove()
if buf.IsEmpty() && len(i.pastedLines) > 0 {
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ:
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
i.Prompt.UseAlt = false
return output, nil
default:
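The new multi-line flow hinges on splitting two keys that were previously handled together: CharEnter (carriage return) now submits, while CharCtrlJ (line feed, which Ctrl+J produces directly and which terminals that distinguish Shift+Enter commonly send for it) stashes the current line in pastedLines and switches to the alternate prompt. Backspace on an empty line pops the last stashed line back for editing. A minimal sketch of just that state handling, with a hypothetical type standing in for the Instance fields and all cursor/prompt redrawing omitted:

package main

import (
	"fmt"
	"strings"
)

// multiline models only the pastedLines/buffer state from Readline above;
// the names are illustrative, not taken from the readline package.
type multiline struct {
	pending []string // corresponds to Instance.pastedLines
	current string   // corresponds to the live line buffer
}

// ctrlJ mirrors the CharCtrlJ case: stash the line and start a new one.
func (m *multiline) ctrlJ() {
	m.pending = append(m.pending, m.current)
	m.current = ""
}

// backspaceOnEmpty mirrors the new CharBackspace branch: pull the previous
// stashed line back into the editable buffer.
func (m *multiline) backspaceOnEmpty() {
	if m.current == "" && len(m.pending) > 0 {
		m.current = m.pending[len(m.pending)-1]
		m.pending = m.pending[:len(m.pending)-1]
	}
}

// enter mirrors the CharEnter case: join everything with newlines and reset.
func (m *multiline) enter() string {
	out := m.current
	if len(m.pending) > 0 {
		out = strings.Join(m.pending, "\n") + "\n" + out
	}
	m.pending, m.current = nil, ""
	return out
}

func main() {
	var m multiline
	m.current = "first line"
	m.ctrlJ()
	m.current = "second line"
	fmt.Printf("%q\n", m.enter()) // "first line\nsecond line"
}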

View File

@@ -216,7 +216,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, &req, streamResponse)
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}

View File

@@ -574,6 +574,7 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}
s.loadedMu.Lock()
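Initializing refCount to 1 presumably marks the freshly loaded image-generation runner as in use by the request that triggered the load, matching how the scheduler reference-counts regular LLM runners so they are not expired mid-request.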

View File

@@ -25,14 +25,6 @@ import (
"github.com/ollama/ollama/x/tools"
)
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -656,7 +648,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
@@ -707,7 +699,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone
for {
line, err := scanner.Readline()
@@ -721,37 +712,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -860,41 +826,18 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
fmt.Println("Usage: /set system <message>")
continue
}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1081,7 +1024,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
if sb.Len() > 0 && multiline == MultilineNone {
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
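With the MultilineState machinery gone, /set system takes its message from the rest of the line, and multi-line prompts come from the readline layer itself. A hypothetical session after this change (the ">>> " and "... " prompts and the "Set system message." confirmation are as defined above; the typed text is illustrative):

>>> /set system Reply in one short sentence.
Set system message.
>>> draft an apology email        <- Shift+Enter continues the input
... keep it under three lines     <- Enter sends both lines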

View File

@@ -2,8 +2,8 @@ package api
import (
"fmt"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
@@ -50,7 +50,7 @@ func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
req.N = 1
}
if req.Size == "" {
req.Size = fmt.Sprintf("%dx%d", imagegen.DefaultWidth, imagegen.DefaultHeight)
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
@@ -62,8 +62,16 @@ func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, api.Options{}, nil)
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
@@ -73,10 +81,10 @@ func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
return
}
// Build completion request with size (OpenAI format)
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Size: req.Size,
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
@@ -126,6 +134,22 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
@@ -161,18 +185,20 @@ func buildResponse(imageBase64, format string) ImageGenerationResponse {
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, req *api.GenerateRequest, streamFn func(c *gin.Context, ch chan any)) {
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, api.Options{}, req.KeepAlive)
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request with size (OpenAI format)
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Size: req.Size,
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
@@ -181,14 +207,15 @@ func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, req *api.G
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: req.Model,
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
slog.Error("image generation failed", "model", req.Model, "error", err)
// Log error but don't block - channel is already being consumed
_ = err
}
}()
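Taken together with the runner changes later in this commit, the handler now parses the OpenAI-style "WxH" size string and carries the dimensions to the image runner through the repurposed NumCtx/NumGPU option fields. A minimal, self-contained sketch of that round trip (defaults, bounds, and field choices are taken from the diff; the type and helper names are illustrative):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// opts stands in for api.Options; only the repurposed fields appear here.
type opts struct {
	NumCtx int // carries the image width
	NumGPU int // carries the image height
}

// parseSize mirrors the handler's parsing of "WxH", falling back to 1024x1024.
func parseSize(size string) (int32, int32) {
	parts := strings.Split(size, "x")
	if len(parts) != 2 {
		return 1024, 1024
	}
	w, _ := strconv.Atoi(parts[0])
	h, _ := strconv.Atoi(parts[1])
	if w == 0 {
		w = 1024
	}
	if h == 0 {
		h = 1024
	}
	return int32(w), int32(h)
}

// unpack mirrors the runner side: bounds-checked fields win, otherwise defaults.
func unpack(o opts) (width, height int32) {
	width, height = 1024, 1024
	if o.NumCtx > 0 && o.NumCtx <= 4096 {
		width = int32(o.NumCtx)
	}
	if o.NumGPU > 0 && o.NumGPU <= 4096 {
		height = int32(o.NumGPU)
	}
	return width, height
}

func main() {
	w, h := parseSize("512x768")
	o := opts{NumCtx: int(w), NumGPU: int(h)}
	fmt.Println(unpack(o)) // 512 768
}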

View File

@@ -37,9 +37,9 @@ type ImageGenOptions struct {
// DefaultOptions returns the default image generation options.
func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: DefaultWidth,
Height: DefaultHeight,
Steps: 0, // 0 means model default
Width: 1024,
Height: 1024,
Steps: 9,
Seed: 0, // 0 means random
}
}
@@ -107,9 +107,9 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", DefaultWidth, "Image width")
cmd.Flags().Int("height", DefaultHeight, "Image height")
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
@@ -158,10 +158,17 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Size: fmt.Sprintf("%dx%d", opts.Width, opts.Height),
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
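Because the CLI flags are packed into ordinary option keys, the same mapping is available to any caller of the generate API. A hypothetical request fragment built with the Go client types (model name and values are illustrative only):

req := &api.GenerateRequest{
	Model:  "my-image-model", // hypothetical model name
	Prompt: "a lighthouse at dusk",
	Options: map[string]any{
		"num_ctx":     768, // width
		"num_gpu":     512, // height
		"num_predict": 9,   // denoising steps
		"seed":        42,
	},
}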

View File

@@ -12,7 +12,6 @@ import (
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
@@ -47,9 +46,9 @@ func main() {
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", imagegen.DefaultWidth, "Image width")
height := flag.Int("height", imagegen.DefaultHeight, "Image height")
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
@@ -150,10 +149,10 @@ func main() {
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != imagegen.DefaultWidth {
if *width != 1024 {
editWidth = int32(*width)
}
if *height != imagegen.DefaultHeight {
if *height != 1024 {
editHeight = int32(*height)
}

View File

@@ -1,7 +0,0 @@
package imagegen
// Default image generation parameters.
const (
DefaultWidth = 1024
DefaultHeight = 1024
)

View File

@@ -95,3 +95,8 @@ func EstimateVRAM(modelName string) uint64 {
}
return 21 * GB
}
// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
return ResolveModelName(modelName) != ""
}
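A hypothetical call site for the new helper (the routing code earlier in this commit still calls ResolveModelName directly, so this is illustrative only):

if imagegen.HasTensorLayers(req.Model) {
	// treat req.Model as an image generation model
}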

View File

@@ -94,6 +94,13 @@ func TestEstimateVRAMDefault(t *testing.T) {
}
}
func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")

View File

@@ -9,7 +9,6 @@ import (
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -167,10 +166,10 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = imagegen.DefaultWidth
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = imagegen.DefaultHeight
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30

View File

@@ -188,13 +188,13 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = imagegen.DefaultWidth
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = imagegen.DefaultHeight
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Z-Image turbo default
cfg.Steps = 9 // Turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0

View File

@@ -136,12 +136,15 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
defer s.mu.Unlock()
// Apply defaults (steps left to model)
// Apply defaults
if req.Width <= 0 {
req.Width = imagegen.DefaultWidth
req.Width = 1024
}
if req.Height <= 0 {
req.Height = imagegen.DefaultHeight
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
}
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()

View File

@@ -33,12 +33,10 @@ type Server struct {
vramSize uint64
done chan error
client *http.Client
stderrLines []string // Recent stderr lines for error reporting (max 10)
stderrLock sync.Mutex
lastErr string // Last stderr line for error reporting
lastErrLock sync.Mutex
}
const maxStderrLines = 10
// completionRequest is sent to the subprocess
type completionRequest struct {
Prompt string `json:"prompt"`
@@ -141,13 +139,10 @@ func NewServer(modelName string) (*Server, error) {
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
// Capture recent stderr lines for error reporting
s.stderrLock.Lock()
s.stderrLines = append(s.stderrLines, line)
if len(s.stderrLines) > maxStderrLines {
s.stderrLines = s.stderrLines[1:]
}
s.stderrLock.Unlock()
// Capture last error line for better error reporting
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
@@ -176,9 +171,7 @@ func (s *Server) ModelPath() string {
return s.modelName
}
// Load is a no-op for image generation models.
// Unlike LLM models, imagegen models are loaded by the subprocess at startup
// rather than through this interface method.
// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -211,16 +204,20 @@ func (s *Server) waitUntilRunning() error {
for {
select {
case err := <-s.done:
// Include recent stderr lines for better error context
stderrContext := s.getStderrContext()
if stderrContext != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", stderrContext, err)
// Include last stderr line for better error context
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout:
stderrContext := s.getStderrContext()
if stderrContext != "" {
return fmt.Errorf("timeout waiting for image runner: %s", stderrContext)
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
}
return errors.New("timeout waiting for image runner to start")
case <-ticker.C:
@@ -232,38 +229,34 @@ func (s *Server) waitUntilRunning() error {
}
}
// getStderrContext returns recent stderr lines joined as a single string.
func (s *Server) getStderrContext() string {
s.stderrLock.Lock()
defer s.stderrLock.Unlock()
if len(s.stderrLines) == 0 {
return ""
}
return strings.Join(s.stderrLines, "; ")
}
// WaitUntilRunning is a no-op for image generation models.
// NewServer already blocks until the subprocess is ready, so this method
// returns immediately. Required by the llm.LlamaServer interface.
// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
}
// Completion generates an image from the prompt via the subprocess.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Build request with defaults (steps left to model)
// Build request
creq := completionRequest{
Prompt: req.Prompt,
Width: DefaultWidth,
Height: DefaultHeight,
Width: 1024,
Height: 1024,
Steps: 9,
Seed: time.Now().UnixNano(),
}
// Parse size string (OpenAI format: "WxH")
if req.Size != "" {
if w, h := parseSize(req.Size); w > 0 && h > 0 {
creq.Width = w
creq.Height = h
if req.Options != nil {
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
creq.Width = int32(req.Options.NumCtx)
}
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
creq.Height = int32(req.Options.NumGPU)
}
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
creq.Steps = req.Options.NumPredict
}
if req.Options.Seed > 0 {
creq.Seed = int64(req.Options.Seed)
}
}
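The 4096 and 100 upper bounds presumably act as sanity checks on the repurposed fields: values outside those ranges are ignored and the 1024x1024, 9-step defaults apply instead.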
@@ -353,20 +346,17 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// Embedding returns an error as image generation models don't produce embeddings.
// Required by the llm.LlamaServer interface.
// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("embedding not supported for image generation models")
}
// Tokenize returns an error as image generation uses internal tokenization.
// Required by the llm.LlamaServer interface.
// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("tokenize not supported for image generation models")
}
// Detokenize returns an error as image generation uses internal tokenization.
// Required by the llm.LlamaServer interface.
// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("detokenize not supported for image generation models")
}
@@ -386,8 +376,7 @@ func (s *Server) GetPort() int {
return s.port
}
// GetDeviceInfos returns nil as GPU tracking is handled by the subprocess.
// Required by the llm.LlamaServer interface.
// GetDeviceInfos returns nil since we don't track GPU info.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
@@ -404,14 +393,3 @@ func (s *Server) HasExited() bool {
// Ensure Server implements llm.LlamaServer
var _ llm.LlamaServer = (*Server)(nil)
// parseSize parses an OpenAI-style size string "WxH" into width and height.
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 0, 0
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
return int32(w), int32(h)
}