Mirror of https://github.com/ollama/ollama.git
Compare commits: main...imagegen-g (1 commit)
Commit: 29837c1b98
api/client.go

```diff
@@ -281,6 +281,20 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
 	})
 }
 
+// XGenerate generates images using the experimental /x/generate endpoint.
+// This endpoint is specifically designed for image generation models and
+// supports parameters like width, height, and steps.
+func (c *Client) XGenerate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
+	return c.stream(ctx, http.MethodPost, "/x/generate", req, func(bts []byte) error {
+		var resp GenerateResponse
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
+
+		return fn(resp)
+	})
+}
+
 // ChatResponseFunc is a function that [Client.Chat] invokes every time
 // a response is received from the service. If this function returns an error,
 // [Client.Chat] will stop generating and return this error.
```
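For orientation, here is a minimal sketch of how a caller might drive the new method end to end, using the request and response fields added in api/types.go below. The model name `flux` and the PNG output path are assumptions for illustration, not part of this commit:

```go
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "flux", // assumed model name
		Prompt: "a sunset over mountains",
		Width:  1024,
		Height: 768,
		Steps:  20,
	}

	err = client.XGenerate(context.Background(), req, func(resp api.GenerateResponse) error {
		// Intermediate responses carry progress; the final one carries images.
		if resp.Total > 0 {
			fmt.Printf("\r%s: %d/%d", resp.Status, resp.Completed, resp.Total)
		}
		if resp.Done && len(resp.Images) > 0 {
			data, err := base64.StdEncoding.DecodeString(resp.Images[0])
			if err != nil {
				return err
			}
			return os.WriteFile("out.png", data, 0o644) // output path assumed
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```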
api/types.go (21 changed lines)

```diff
@@ -97,6 +97,15 @@ type GenerateRequest struct {
 	// request, for multimodal models.
 	Images []ImageData `json:"images,omitempty"`
 
+	// Width is the width of the generated image (for image generation models).
+	Width int32 `json:"width,omitempty"`
+
+	// Height is the height of the generated image (for image generation models).
+	Height int32 `json:"height,omitempty"`
+
+	// Steps is the number of diffusion steps (for image generation models).
+	Steps int32 `json:"steps,omitempty"`
+
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]any `json:"options"`
@@ -860,6 +869,18 @@ type GenerateResponse struct {
 	// Logprobs contains log probability information for the generated tokens,
 	// if requested via the Logprobs parameter.
 	Logprobs []Logprob `json:"logprobs,omitempty"`
+
+	// Status describes the current phase of generation (e.g., "generating image").
+	Status string `json:"status,omitempty"`
+
+	// Total is the total count for the current phase (e.g., total steps).
+	Total int64 `json:"total,omitempty"`
+
+	// Completed is the completed count for the current phase.
+	Completed int64 `json:"completed,omitempty"`
+
+	// Images contains base64-encoded generated images for image generation models.
+	Images []string `json:"images,omitempty"`
 }
 
 // ModelDetails provides details about a model.
```
docs/api.md (83 changed lines)

````diff
@@ -16,6 +16,7 @@
 - [Generate Embeddings](#generate-embeddings)
 - [List Running Models](#list-running-models)
 - [Version](#version)
+- [Experimental: Generate an image](#generate-an-image-experimental)
 
 ## Conventions
 
@@ -1867,3 +1868,85 @@ curl http://localhost:11434/api/version
   "version": "0.5.1"
 }
 ```
+
+## Experimental Endpoints
+
+### Generate an image (Experimental)
+
+```
+POST /x/generate
+```
+
+> [!WARNING]
+> This endpoint is experimental and may change in future versions.
+
+Generate an image using an image generation model. This endpoint is specifically designed for diffusion-based image generation models.
+
+#### Parameters
+
+- `model`: (required) the [model name](#model-names) of an image generation model
+- `prompt`: the text prompt describing the image to generate
+
+Image generation parameters (optional):
+
+- `width`: width of the generated image in pixels (default: model-specific, typically 1024)
+- `height`: height of the generated image in pixels (default: model-specific, typically 1024)
+- `steps`: number of diffusion steps (default: model-specific)
+
+Other parameters:
+
+- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
+
+#### Response
+
+The response is streamed as JSON objects showing generation progress:
+
+- `status`: describes the current phase (e.g., "generating image")
+- `total`: total number of steps
+- `completed`: number of completed steps
+- `done`: whether generation is complete
+
+The final response includes:
+
+- `images`: array of base64-encoded generated images
+- `total_duration`: time spent generating the image
+- `load_duration`: time spent loading the model
+
+#### Examples
+
+##### Request
+
+```shell
+curl http://localhost:11434/x/generate -d '{
+  "model": "flux",
+  "prompt": "a sunset over mountains",
+  "width": 1024,
+  "height": 768
+}'
+```
+
+##### Response (streaming)
+
+```json
+{
+  "model": "flux",
+  "created_at": "2024-01-15T10:30:00.000000Z",
+  "status": "generating image",
+  "completed": 5,
+  "total": 20,
+  "done": false
+}
+```
+
+##### Final Response
+
+```json
+{
+  "model": "flux",
+  "created_at": "2024-01-15T10:30:15.000000Z",
+  "images": ["iVBORw0KGgoAAAANSUhEUg..."],
+  "done": true,
+  "total_duration": 15000000000,
+  "load_duration": 2000000000
+}
+```
````
llm/server.go

```diff
@@ -1468,6 +1468,7 @@ type CompletionRequest struct {
 	// Image generation fields
 	Width  int32 `json:"width,omitempty"`
 	Height int32 `json:"height,omitempty"`
+	Steps  int32 `json:"steps,omitempty"`
 	Seed   int64 `json:"seed,omitempty"`
 }
 
@@ -1518,10 +1519,14 @@ type CompletionResponse struct {
 	// Logprobs contains log probability information if requested
 	Logprobs []Logprob `json:"logprobs,omitempty"`
 
-	// Image generation fields
-	Image []byte `json:"image,omitempty"` // Generated image
-	Step  int    `json:"step,omitempty"`  // Current generation step
-	Total int    `json:"total,omitempty"` // Total generation steps
+	// Image contains base64-encoded image data for image generation
+	Image string `json:"image,omitempty"`
+
+	// Step is the current step in image generation
+	Step int `json:"step,omitempty"`
+
+	// TotalSteps is the total number of steps for image generation
+	TotalSteps int `json:"total_steps,omitempty"`
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
```
middleware (OpenAI compatibility)

```diff
@@ -50,6 +50,11 @@ type EmbedWriter struct {
 	encodingFormat string
 }
 
+type ImageWriter struct {
+	BaseWriter
+	done bool
+}
+
 func (w *BaseWriter) writeError(data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
@@ -274,6 +279,36 @@ func (w *EmbedWriter) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
+func (w *ImageWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	err := json.Unmarshal(data, &generateResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// Image generation doesn't support streaming in the OpenAI API sense,
+	// so we only write the response when done with images
+	if generateResponse.Done && len(generateResponse.Images) > 0 {
+		w.done = true
+		w.ResponseWriter.Header().Set("Content-Type", "application/json")
+		err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	return len(data), nil
+}
+
+func (w *ImageWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(data)
+	}
+
+	return w.writeResponse(data)
+}
+
 func ListMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		w := &ListWriter{
@@ -393,6 +428,43 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
 	}
 }
 
+func ImageGenerationsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req openai.ImageGenerationRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Prompt == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
+			return
+		}
+
+		if req.Model == "" {
+			c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
+			return
+		}
+
+		var b bytes.Buffer
+		genReq := openai.FromImageGenerationRequest(req)
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &ImageWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+		c.Next()
+	}
+}
+
 func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req openai.ChatCompletionRequest
```
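Once the route registration in server/routes.go (below) wires this middleware in front of XGenerateHandler, a plain HTTP client can hit the OpenAI-style endpoint directly; because FromImageGenerationRequest forces `stream: false`, the ImageWriter emits exactly one OpenAI-shaped JSON object. A minimal sketch, assuming a local server and a model named `flux`:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":  "flux", // assumed model name
		"prompt": "a sunset over mountains",
	})

	resp, err := http.Post("http://localhost:11434/v1/images/generations",
		"application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Shape produced by ToImageGenerationResponse:
	// {"created": ..., "data": [{"b64_json": "..."}]}
	var out struct {
		Created int64 `json:"created"`
		Data    []struct {
			B64JSON string `json:"b64_json"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("received %d image(s)\n", len(out.Data))
}
```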
middleware tests (OpenAI compatibility)

```diff
@@ -961,3 +961,143 @@ func TestRetrieveMiddleware(t *testing.T) {
 		}
 	}
 }
+
+func TestImageGenerationsMiddleware(t *testing.T) {
+	type testCase struct {
+		name string
+		body string
+		req  api.GenerateRequest
+		err  openai.ErrorResponse
+	}
+
+	var capturedRequest *api.GenerateRequest
+
+	streamFalse := false
+	testCases := []testCase{
+		{
+			name: "image generation handler",
+			body: `{
+				"model": "flux",
+				"prompt": "a cat"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "flux",
+				Prompt: "a cat",
+				Stream: &streamFalse,
+			},
+		},
+		{
+			name: "image generation with size",
+			body: `{
+				"model": "flux",
+				"prompt": "a dog",
+				"size": "512x512"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "flux",
+				Prompt: "a dog",
+				Stream: &streamFalse,
+			},
+		},
+		{
+			name: "missing prompt error",
+			body: `{
+				"model": "flux"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "prompt is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+		{
+			name: "missing model error",
+			body: `{
+				"prompt": "a cat"
+			}`,
+			err: openai.ErrorResponse{
+				Error: openai.Error{
+					Message: "model is required",
+					Type:    "invalid_request_error",
+				},
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
+			req.Header.Set("Content-Type", "application/json")
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			var errResp openai.ErrorResponse
+			if resp.Code != http.StatusOK {
+				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
+				t.Fatalf("requests did not match\nExpected: %+v\nActual: %+v", tc.req, *capturedRequest)
+			}
+
+			if !reflect.DeepEqual(tc.err, errResp) {
+				t.Fatalf("errors did not match\nExpected: %+v\nActual: %+v", tc.err, errResp)
+			}
+
+			capturedRequest = nil
+		})
+	}
+}
+
+func TestImageWriterIntegration(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	t.Run("transforms generate response to openai format", func(t *testing.T) {
+		router := gin.New()
+		router.Use(ImageGenerationsMiddleware())
+		router.POST("/api/generate", func(c *gin.Context) {
+			// Simulate an image generation response
+			generateResponse := api.GenerateResponse{
+				Done:      true,
+				CreatedAt: time.Now(),
+				Images:    []string{"base64encodedimage"},
+			}
+			c.JSON(http.StatusOK, generateResponse)
+		})
+
+		req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(`{"model":"flux","prompt":"a cat"}`))
+		req.Header.Set("Content-Type", "application/json")
+
+		resp := httptest.NewRecorder()
+		router.ServeHTTP(resp, req)
+
+		if resp.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
+		}
+
+		var response openai.ImageGenerationResponse
+		if err := json.Unmarshal(resp.Body.Bytes(), &response); err != nil {
+			t.Fatalf("failed to unmarshal response: %v", err)
+		}
+
+		if len(response.Data) != 1 {
+			t.Fatalf("expected 1 image, got %d", len(response.Data))
+		}
+		if response.Data[0].B64JSON != "base64encodedimage" {
+			t.Fatalf("expected image data 'base64encodedimage', got '%s'", response.Data[0].B64JSON)
+		}
+	})
+}
```
openai (request/response conversion)

```diff
@@ -737,3 +737,46 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		DebugRenderOnly: r.DebugRenderOnly,
 	}, nil
 }
+
+// ImageGenerationRequest is an OpenAI-compatible image generation request.
+type ImageGenerationRequest struct {
+	Model          string `json:"model"`
+	Prompt         string `json:"prompt"`
+	N              int    `json:"n,omitempty"`
+	Size           string `json:"size,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+}
+
+// ImageGenerationResponse is an OpenAI-compatible image generation response.
+type ImageGenerationResponse struct {
+	Created int64            `json:"created"`
+	Data    []ImageURLOrData `json:"data"`
+}
+
+// ImageURLOrData contains either a URL or base64-encoded image data.
+type ImageURLOrData struct {
+	URL     string `json:"url,omitempty"`
+	B64JSON string `json:"b64_json,omitempty"`
+}
+
+// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
+func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
+	stream := false
+	return api.GenerateRequest{
+		Model:  r.Model,
+		Prompt: r.Prompt,
+		Stream: &stream,
+	}
+}
+
+// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
+func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
+	data := make([]ImageURLOrData, 0)
+	for _, img := range resp.Images {
+		data = append(data, ImageURLOrData{B64JSON: img})
+	}
+	return ImageGenerationResponse{
+		Created: resp.CreatedAt.Unix(),
+		Data:    data,
+	}
+}
```
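Note that FromImageGenerationRequest deliberately drops `N`, `Size`, and `ResponseFormat` for now; the middleware test above expects a request with no width or height even when `size` is "512x512". If size mapping is wanted later, a parse along the lines of the removed handleImageGeneration (see server/routes.go below) would be the obvious shape. A hypothetical sketch, not part of this commit:

```go
package openai // hypothetical placement

import "fmt"

// parseSize converts an OpenAI-style "WxH" string (e.g. "1024x768") into
// width and height, falling back to 1024x1024 when unset or malformed.
// Hypothetical helper, not part of this commit.
func parseSize(size string) (width, height int32) {
	width, height = 1024, 1024
	if size == "" {
		return width, height
	}
	if _, err := fmt.Sscanf(size, "%dx%d", &width, &height); err != nil {
		return 1024, 1024
	}
	return width, height
}
```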
server/routes.go (162 changed lines)

```diff
@@ -1587,6 +1587,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.POST("/api/embed", s.EmbedHandler)
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
 
+	// Experimental image generation
+	r.POST("/x/generate", s.XGenerateHandler)
+
 	// Inference (OpenAI compatibility)
 	r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
 	r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
@@ -1594,8 +1597,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
 	r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
-	// Experimental OpenAI-compatible image generation endpoint
-	r.POST("/v1/images/generations", s.handleImageGeneration)
+	// OpenAI-compatible image generation endpoint (uses experimental /x/generate)
+	r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.XGenerateHandler)
 
 	// Inference (Anthropic compatibility)
 	r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -1908,6 +1911,105 @@ func (s *Server) PsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
+// XGenerateHandler handles the experimental /x/generate endpoint for image generation.
+func (s *Server) XGenerateHandler(c *gin.Context) {
+	checkpointStart := time.Now()
+
+	var req api.GenerateRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	name := model.ParseName(req.Model)
+	if !name.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+		return
+	}
+
+	name, err := getExistingName(name)
+	if err != nil {
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+		return
+	}
+
+	m, err := GetModel(name.String())
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	// Check that this is an image generation model
+	if !slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support image generation", req.Model)})
+		return
+	}
+
+	// Schedule the runner
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{model.CapabilityImageGeneration}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	checkpointLoaded := time.Now()
+
+	// Handle load-only request
+	if req.Prompt == "" {
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Done:       true,
+			DoneReason: "load",
+		})
+		return
+	}
+
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt: req.Prompt,
+			Width:  req.Width,
+			Height: req.Height,
+			Steps:  req.Steps,
+		}, func(cr llm.CompletionResponse) {
+			res := api.GenerateResponse{
+				Model:     req.Model,
+				CreatedAt: time.Now().UTC(),
+				Done:      cr.Done,
+			}
+
+			// Image generation progress
+			if cr.TotalSteps > 0 {
+				res.Status = "generating image"
+				res.Completed = int64(cr.Step)
+				res.Total = int64(cr.TotalSteps)
+			}
+
+			// Final image
+			if cr.Image != "" {
+				res.Images = []string{cr.Image}
+			}
+
+			if cr.Done {
+				res.DoneReason = cr.DoneReason.String()
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			}
+
+			ch <- res
+		}); err != nil {
+			ch <- gin.H{"error": err.Error()}
+		}
+	}()
+
+	streamResponse(c, ch)
+}
+
 func toolCallId() string {
 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
 	b := make([]byte, 8)
@@ -1917,62 +2019,6 @@ func toolCallId() string {
 	return "call_" + strings.ToLower(string(b))
 }
 
-func (s *Server) handleImageGeneration(c *gin.Context) {
-	var req struct {
-		Model  string `json:"model"`
-		Prompt string `json:"prompt"`
-		Size   string `json:"size"`
-	}
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	m, err := GetModel(req.Model)
-	if err != nil {
-		c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
-		return
-	}
-
-	runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
-	var runner *runnerRef
-	select {
-	case runner = <-runnerCh:
-	case err := <-errCh:
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	// Parse size (e.g., "1024x768") into width and height
-	width, height := int32(1024), int32(1024)
-	if req.Size != "" {
-		if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
-			return
-		}
-	}
-
-	var image []byte
-	err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-		Prompt: req.Prompt,
-		Width:  width,
-		Height: height,
-	}, func(resp llm.CompletionResponse) {
-		if len(resp.Image) > 0 {
-			image = resp.Image
-		}
-	})
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	c.JSON(http.StatusOK, gin.H{
-		"created": time.Now().Unix(),
-		"data":    []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
-	})
-}
-
 func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 
```
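One detail worth calling out in XGenerateHandler: as with the text endpoints, an empty prompt acts as a load-only request, so the runner is scheduled and the handler replies immediately with done_reason "load". A minimal preload sketch using the client method added in api/client.go above (model name assumed):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Empty Prompt: schedules the model into memory and returns immediately.
	req := &api.GenerateRequest{Model: "flux"} // assumed model name
	err = client.XGenerate(context.Background(), req, func(resp api.GenerateResponse) error {
		if resp.Done && resp.DoneReason == "load" {
			fmt.Println("model loaded")
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```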
CLI (image generation)

```diff
@@ -91,9 +91,7 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
 }
 
 // generateImageWithOptions generates an image with the given options.
-// Note: opts are currently unused as the native API doesn't support size parameters.
-// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
-func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
+func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
@@ -102,7 +100,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
 	req := &api.GenerateRequest{
 		Model:  modelName,
 		Prompt: prompt,
-		// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
+		Width:  int32(opts.Width),
+		Height: int32(opts.Height),
+		Steps:  int32(opts.Steps),
 	}
 	if keepAlive != nil {
 		req.KeepAlive = keepAlive
@@ -115,27 +115,20 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
 
 	var stepBar *progress.StepBar
 	var imageBase64 string
-	err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
-		content := resp.Response
-
-		// Handle progress updates - parse step info and switch to step bar
-		if strings.HasPrefix(content, "\rGenerating:") {
-			var step, total int
-			fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
-			if stepBar == nil && total > 0 {
+	err = client.XGenerate(cmd.Context(), req, func(resp api.GenerateResponse) error {
+		// Handle progress updates using structured fields
+		if resp.Total > 0 && resp.Completed > 0 {
+			if stepBar == nil {
 				spinner.Stop()
-				stepBar = progress.NewStepBar("Generating", total)
+				stepBar = progress.NewStepBar("Generating", int(resp.Total))
 				p.Add("", stepBar)
 			}
-			if stepBar != nil {
-				stepBar.Set(step)
-			}
-			return nil
+			stepBar.Set(int(resp.Completed))
 		}
 
-		// Handle final response with base64 image data
-		if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
-			imageBase64 = content[13:]
+		// Handle final response with image data
+		if resp.Done && len(resp.Images) > 0 {
+			imageBase64 = resp.Images[0]
 		}
 
 		return nil
@@ -235,12 +228,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
 	req := &api.GenerateRequest{
 		Model:  modelName,
 		Prompt: line,
-		Options: map[string]any{
-			"num_ctx":     opts.Width,
-			"num_gpu":     opts.Height,
-			"num_predict": opts.Steps,
-			"seed":        opts.Seed,
-		},
+		Width:  int32(opts.Width),
+		Height: int32(opts.Height),
+		Steps:  int32(opts.Steps),
 	}
 	if keepAlive != nil {
 		req.KeepAlive = keepAlive
@@ -254,27 +244,20 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
 	var stepBar *progress.StepBar
 	var imageBase64 string
 
-	err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
-		content := resp.Response
-
-		// Handle progress updates - parse step info and switch to step bar
-		if strings.HasPrefix(content, "\rGenerating:") {
-			var step, total int
-			fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
-			if stepBar == nil && total > 0 {
+	err = client.XGenerate(cmd.Context(), req, func(resp api.GenerateResponse) error {
+		// Handle progress updates using structured fields
+		if resp.Total > 0 && resp.Completed > 0 {
+			if stepBar == nil {
 				spinner.Stop()
-				stepBar = progress.NewStepBar("Generating", total)
+				stepBar = progress.NewStepBar("Generating", int(resp.Total))
 				p.Add("", stepBar)
 			}
-			if stepBar != nil {
-				stepBar.Set(step)
-			}
-			return nil
+			stepBar.Set(int(resp.Completed))
 		}
 
-		// Handle final response with base64 image data
-		if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
-			imageBase64 = content[13:]
+		// Handle final response with image data
+		if resp.Done && len(resp.Images) > 0 {
+			imageBase64 = resp.Images[0]
 		}
 
 		return nil
```
image generation runner

```diff
@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"bytes"
 	"context"
-	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -233,11 +232,13 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
 		Prompt string `json:"prompt"`
 		Width  int32  `json:"width,omitempty"`
 		Height int32  `json:"height,omitempty"`
+		Steps  int32  `json:"steps,omitempty"`
 		Seed   int64  `json:"seed,omitempty"`
 	}{
 		Prompt: req.Prompt,
 		Width:  req.Width,
 		Height: req.Height,
+		Steps:  req.Steps,
 		Seed:   seed,
 	}
 
@@ -280,15 +281,11 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
 
 	// Convert to llm.CompletionResponse
 	cresp := llm.CompletionResponse{
-		Content: raw.Content,
-		Done:    raw.Done,
-		Step:    raw.Step,
-		Total:   raw.Total,
-	}
-	if raw.Image != "" {
-		if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
-			cresp.Image = data
-		}
-	}
+		Content:    raw.Content,
+		Done:       raw.Done,
+		Step:       raw.Step,
+		TotalSteps: raw.Total,
+		Image:      raw.Image,
+	}
 
 	fn(cresp)
```
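Read together, the two hunks above pin down the server-to-runner wire format after this change. The struct names below are illustrative (the runner side uses an anonymous struct), and the response field set is inferred from the `raw` fields referenced in the conversion; treat it as a sketch, not a spec:

```go
package runner // illustrative placement

// Request body POSTed to the image generation runner
// (mirrors the anonymous struct in Completion above).
type runnerRequest struct {
	Prompt string `json:"prompt"`
	Width  int32  `json:"width,omitempty"`
	Height int32  `json:"height,omitempty"`
	Steps  int32  `json:"steps,omitempty"`
	Seed   int64  `json:"seed,omitempty"`
}

// Streamed response from the runner, inferred from the fields of `raw`.
// Image is now passed through as base64 text rather than decoded to
// []byte, which is why "encoding/base64" could be dropped from the imports.
type runnerResponse struct {
	Content string `json:"content"`
	Done    bool   `json:"done"`
	Step    int    `json:"step"`
	Total   int    `json:"total"`
	Image   string `json:"image"`
}
```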