Files
ollama/x/imagegen/runner.go
2026-02-05 18:25:56 -08:00

204 lines
5.1 KiB
Go

//go:build mlx
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
package imagegen
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// Execute is the entry point for the unified MLX runner subprocess.
//
// It parses args (--model and --port), initializes MLX, loads the model in
// the mode chosen by detectModelMode, and serves HTTP on 127.0.0.1:<port>
// until SIGINT or SIGTERM arrives. On a signal the server is shut down with a
// 5 second grace period and Execute returns nil. Any error that prevents the
// runner from starting (bad flags, MLX init, model load, listen failure) is
// returned to the caller.
func Execute(args []string) error {
	// Set up logging with appropriate level from environment
	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))

	fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
	modelName := fs.String("model", "", "path to model")
	port := fs.Int("port", 0, "port to listen on")
	if err := fs.Parse(args); err != nil {
		return err
	}
	if *modelName == "" {
		return fmt.Errorf("--model is required")
	}
	if *port == 0 {
		return fmt.Errorf("--port is required")
	}
	// Reject impossible ports before the model load below rather than after
	// it, when ListenAndServe would fail anyway.
	if *port < 0 || *port > 65535 {
		return fmt.Errorf("--port must be between 1 and 65535, got %d", *port)
	}

	// Initialize MLX
	if err := mlx.InitMLX(); err != nil {
		slog.Error("unable to initialize MLX", "error", err)
		return err
	}
	slog.Info("MLX library initialized")

	// Detect model type from capabilities
	mode := detectModelMode(*modelName)
	slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)

	// Create the server; this also loads the model.
	server, err := newServer(*modelName, *port, mode)
	if err != nil {
		return fmt.Errorf("failed to create server: %w", err)
	}

	// Set up HTTP handlers
	mux := http.NewServeMux()
	mux.HandleFunc("/health", server.healthHandler)
	mux.HandleFunc("/completion", server.completionHandler)
	// LLM-specific endpoints
	if mode == ModeLLM {
		mux.HandleFunc("/tokenize", server.tokenizeHandler)
		mux.HandleFunc("/embedding", server.embeddingHandler)
	}

	httpServer := &http.Server{
		Addr:    fmt.Sprintf("127.0.0.1:%d", *port),
		Handler: mux,
		// Bound only the time to read request headers (slowloris protection).
		// Whole-request Read/Write timeouts are deliberately not set because
		// completion handlers may run for a long time (assumption — confirm
		// against handleImageCompletion / handleLLMCompletion).
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Cancel ctx on SIGINT/SIGTERM; stop() unregisters the handler on return.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	// Serve in the background so that both a listen failure and a shutdown
	// signal can be observed here without leaving a goroutine blocked forever
	// on either path. Buffered so the sender never blocks.
	serveErr := make(chan error, 1)
	slog.Info("mlx runner listening", "addr", httpServer.Addr)
	go func() {
		serveErr <- httpServer.ListenAndServe()
	}()

	select {
	case err := <-serveErr:
		// ListenAndServe always returns a non-nil error; before Shutdown is
		// called this is a real failure such as the port being in use.
		return err
	case <-ctx.Done():
	}

	slog.Info("shutting down mlx runner")
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := httpServer.Shutdown(shutdownCtx); err != nil {
		// Best effort: report an incomplete drain (e.g. grace period
		// exceeded) but still exit cleanly, as before.
		slog.Warn("mlx runner shutdown did not complete cleanly", "error", err)
	}

	// Shutdown closes the listener first, so ListenAndServe has returned (or
	// is about to) with ErrServerClosed.
	if err := <-serveErr; err != http.ErrServerClosed {
		return err
	}
	return nil
}
// detectModelMode picks the runner mode for modelName.
//
// It asks DetectModelType for the model's type; per the original check, that
// comes from model_index.json. One of the known image pipelines
// (ZImagePipeline, FluxPipeline, Flux2KleinPipeline) selects ModeImageGen.
// Any other value, including an empty one when no type is detected, selects
// ModeLLM.
func detectModelMode(modelName string) ModelMode {
	switch DetectModelType(modelName) {
	case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
		return ModeImageGen
	default:
		// Safetensors models without a recognized image pipeline run as LLMs.
		return ModeLLM
	}
}
// server is the per-process state behind the runner's HTTP endpoints: the
// serving mode, the flags it was started with, and the loaded model. Which
// model field is meaningful depends on mode.
type server struct {
	mode      ModelMode // ModeImageGen or ModeLLM
	modelName string    // value of --model
	port      int       // value of --port

	// imageModel is the loaded model when mode == ModeImageGen.
	imageModel ImageModel
	// llmModel is the loaded model when mode == ModeLLM; handlers check it
	// for nil before use.
	llmModel *llmState
}
// newServer creates a server for modelName and loads the model that matches
// mode: loadImageModel for ModeImageGen, loadLLMModel for ModeLLM.
//
// port is only recorded on the returned server; newServer does not listen.
// It returns an error if loading fails or if mode is not one of the two
// supported modes.
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
	s := &server{
		mode:      mode,
		modelName: modelName,
		port:      port,
	}

	switch mode {
	case ModeImageGen:
		if err := s.loadImageModel(); err != nil {
			return nil, fmt.Errorf("failed to load image model: %w", err)
		}
	case ModeLLM:
		if err := s.loadLLMModel(); err != nil {
			return nil, fmt.Errorf("failed to load LLM model: %w", err)
		}
	default:
		// Without this, an unrecognized mode would return a server with no
		// model loaded at all, and the mistake would only surface at request
		// time.
		return nil, fmt.Errorf("unsupported model mode %v", mode)
	}
	return s, nil
}
// healthHandler replies to every request with a JSON-encoded HealthResponse
// whose Status is "ok". It does not look at the request or the model.
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	// The encode error is intentionally dropped: the only thing left to do
	// with it would be to write to the same (presumably broken) response.
	_ = json.NewEncoder(w).Encode(HealthResponse{Status: "ok"})
}
// completionHandler accepts POST requests with a JSON-encoded Request body
// and dispatches them by the server's mode: to handleImageCompletion for
// ModeImageGen, or to handleLLMCompletion for ModeLLM.
//
// It responds with 405 for non-POST methods and 400 for an undecodable body.
// It responds with 500 if the server is in an unrecognized mode; previously
// such a request got an empty 200 response.
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req Request
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	switch s.mode {
	case ModeImageGen:
		s.handleImageCompletion(w, r, req)
	case ModeLLM:
		s.handleLLMCompletion(w, r, req)
	default:
		http.Error(w, fmt.Sprintf("unsupported model mode %v", s.mode), http.StatusInternalServerError)
	}
}
// tokenizeHandler tokenizes the "content" field of a JSON request body with
// the loaded LLM's tokenizer and replies with {"tokens": [...]}.
//
// It responds with 500 if no LLM is loaded and 400 if the body cannot be
// decoded.
func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
	llm := s.llmModel
	if llm == nil {
		http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
		return
	}

	var body struct {
		Content string `json:"content"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	ids := llm.model.Tokenizer().Encode(body.Content, false)

	// The tokenizer's IDs are int32; widen them to int for the JSON reply.
	// The slice is non-nil even when empty, so it encodes as [] rather than null.
	out := make([]int, 0, len(ids))
	for _, id := range ids {
		out = append(out, int(id))
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string][]int{"tokens": out})
}
// embeddingHandler backs the /embedding endpoint. Embeddings are not
// implemented for MLX models yet, so every request is answered with
// 501 Not Implemented.
func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
	const msg = "embeddings not yet implemented for MLX models"
	http.Error(w, msg, http.StatusNotImplemented)
}