Compare commits


1 Commit

Author:  jmorganca
SHA1:    87cb080a91
Message: support other modelfile commands for image generation models
Date:    2026-01-10 12:39:44 -08:00
48 changed files with 473 additions and 2350 deletions

View File

@@ -13,7 +13,7 @@ body:
     id: logs
     attributes:
       label: Relevant log output
-      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
+      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
       render: shell
     validations:
       required: false

View File

@@ -372,17 +372,13 @@ jobs:
           outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
           cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
          cache-to: type=inline
-      - name: Deduplicate CUDA libraries
-        run: |
-          ./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
       - run: |
           for COMPONENT in bin/* lib/ollama/*; do
             case "$COMPONENT" in
-              bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
-              lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
               lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
               lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
               lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -48,10 +48,9 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
     set(GGML_CPU_ALL_VARIANTS ON)
 endif()

-if(APPLE)
+if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
     set(CMAKE_BUILD_RPATH "@loader_path")
     set(CMAKE_INSTALL_RPATH "@loader_path")
-    set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
 endif()

 set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -197,14 +196,6 @@ if(MLX_ENGINE)
         FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
     )

-    # Install the Metal library for macOS arm64 (must be colocated with the binary)
-    # Metal backend is only built for arm64, not x86_64
-    if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
-        install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
-            DESTINATION ${OLLAMA_INSTALL_DIR}
-            COMPONENT MLX)
-    endif()
-
     # Manually install cudart and cublas since they might not be picked up as direct dependencies
     if(CUDAToolkit_FOUND)
         file(GLOB CUDART_LIBS

View File

@@ -161,9 +161,6 @@ ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1
 ARG CGO_CFLAGS
 ARG CGO_CXXFLAGS
-RUN mkdir -p dist/bin
-RUN --mount=type=cache,target=/root/.cache/go-build \
-    go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .

 FROM base AS build
 WORKDIR /go/src/github.com/ollama/ollama
@@ -185,7 +182,6 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
 COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
 COPY --from=vulkan dist/lib/ollama /lib/ollama/
 COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
-COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/

 FROM --platform=linux/arm64 scratch AS arm64
 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
     return nil
 }

-const maxBufferSize = 8 * format.MegaByte
+const maxBufferSize = 512 * format.KiloByte

 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
     var buf io.Reader
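For context on the constant above: Go's bufio.Scanner rejects tokens larger than 64 KiB unless given a bigger buffer, which is why the stream reader sizes one explicitly. A minimal sketch of the pattern (illustrative only, not the client's actual stream loop):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

const maxBufferSize = 512 * 1024 // 512 KiB, mirroring the new constant

func main() {
	// Stand-in for a line-delimited JSON response body.
	body := strings.NewReader(`{"status":"pulling"}` + "\n" + `{"status":"done"}` + "\n")

	scanner := bufio.NewScanner(body)
	// Raise the token limit past the 64 KiB default so a large
	// single-line payload doesn't fail with bufio.ErrTooLong.
	scanner.Buffer(make([]byte, 0, 512), maxBufferSize)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		fmt.Println("scan error:", err)
	}
}
```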

View File

@@ -100,8 +100,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
     if filename == "" {
         // No Modelfile found - check if current directory is an image gen model
         if imagegen.IsTensorModelDir(".") {
-            quantize, _ := cmd.Flags().GetString("quantize")
-            return imagegenclient.CreateModel(args[0], ".", quantize, p)
+            return imagegenclient.CreateModel(args[0], ".", p)
         }
         reader = strings.NewReader("FROM .\n")
     } else {
@@ -124,6 +123,21 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
         return err
     }

+    // Check if FROM points to an imagegen model directory
+    for _, mfCmd := range modelfile.Commands {
+        if mfCmd.Name == "model" {
+            // Resolve the path relative to the Modelfile directory
+            fromPath := mfCmd.Args
+            if !filepath.IsAbs(fromPath) {
+                fromPath = filepath.Join(filepath.Dir(filename), fromPath)
+            }
+            if imagegen.IsTensorModelDir(fromPath) {
+                return imagegenclient.CreateModelFromModelfile(args[0], fromPath, modelfile.Commands, p)
+            }
+            break
+        }
+    }
+
     status := "gathering model components"
     spinner := progress.NewSpinner(status)
     p.Add(status, spinner)
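The added block resolves a relative FROM argument against the directory containing the Modelfile before probing it for tensor files. A self-contained sketch of just that resolution step (the function and example paths are illustrative, not the package's API):

```go
package main

import (
	"fmt"
	"path/filepath"
)

// resolveFrom mirrors the resolution above: a relative FROM path in a
// Modelfile is interpreted relative to the Modelfile's own directory.
func resolveFrom(modelfilePath, fromArg string) string {
	if filepath.IsAbs(fromArg) {
		return fromArg
	}
	return filepath.Join(filepath.Dir(modelfilePath), fromArg)
}

func main() {
	fmt.Println(resolveFrom("/models/z-image/Modelfile", "./weights"))
	// prints: /models/z-image/weights
}
```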
@@ -465,6 +479,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
     name := args[0]

+    // Check if this is a known image generation model (skip Show/Pull)
+    if imagegen.HasTensorLayers(name) {
+        if opts.Prompt == "" && !interactive {
+            return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
+        }
+        return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
+    }
+
     info, err := func() (*api.ShowResponse, error) {
         showReq := &api.ShowRequest{Name: name}
         info, err := client.Show(cmd.Context(), showReq)
@@ -526,18 +548,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
         return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
     }

-    // Check if this is an image generation model
-    if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
-        if opts.Prompt == "" && !interactive {
-            return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
-        }
-        return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
-    }
-
     // Check for experimental flag
     isExperimental, _ := cmd.Flags().GetBool("experimental")
     yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
-    enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")

     if interactive {
         if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -567,7 +580,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
     // Use experimental agent loop with tools
     if isExperimental {
-        return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
+        return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
     }

     return generateInteractive(cmd, opts)
@@ -673,11 +686,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
         bar, ok := bars[resp.Digest]
         if !ok {
-            msg := resp.Status
-            if msg == "" {
-                msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
-            }
-            bar = progress.NewBar(msg, resp.Total, resp.Completed)
+            bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
             bars[resp.Digest] = bar
             p.Add(resp.Digest, bar)
         }
@@ -843,6 +852,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 }

 func ShowHandler(cmd *cobra.Command, args []string) error {
+    // Check if this is an image generation model
+    if imagegen.HasTensorLayers(args[0]) {
+        return imagegen.Show(args[0], os.Stdout)
+    }
+
     client, err := api.ClientFromEnvironment()
     if err != nil {
         return err
@@ -1787,7 +1801,6 @@ func NewCLI() *cobra.Command {
     runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
     runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
     runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
-    runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
     // Image generation flags (width, height, steps, seed, etc.)
     imagegen.RegisterFlags(runCmd)
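imagegen.RegisterFlags gathers the image-generation flags onto the run command. As a rough illustration of that pattern with cobra, here is a hypothetical stand-in; the flag names and defaults below are assumptions based only on the comment above, not the package's actual contents:

```go
package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

// registerImageGenFlags is a hypothetical stand-in for imagegen.RegisterFlags.
func registerImageGenFlags(cmd *cobra.Command) {
	cmd.Flags().Int("width", 1024, "Output image width in pixels")
	cmd.Flags().Int("height", 1024, "Output image height in pixels")
	cmd.Flags().Int("steps", 20, "Number of diffusion steps")
	cmd.Flags().Int64("seed", -1, "Random seed (-1 for random)")
}

func main() {
	runCmd := &cobra.Command{Use: "run"}
	registerImageGenFlags(runCmd)
	w, _ := runCmd.Flags().GetInt("width")
	fmt.Println("default width:", w)
}
```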

View File

@@ -1547,79 +1547,6 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
     }
 }

-func TestShowInfoImageGen(t *testing.T) {
-    var b bytes.Buffer
-    err := showInfo(&api.ShowResponse{
-        Details: api.ModelDetails{
-            Family:            "ZImagePipeline",
-            ParameterSize:     "10.3B",
-            QuantizationLevel: "FP8",
-        },
-        Capabilities: []model.Capability{model.CapabilityImageGeneration},
-        Requires:     "0.14.0",
-    }, false, &b)
-    if err != nil {
-        t.Fatal(err)
-    }
-
-    expect := " Model\n" +
-        " architecture ZImagePipeline \n" +
-        " parameters 10.3B \n" +
-        " quantization FP8 \n" +
-        " requires 0.14.0 \n" +
-        "\n" +
-        " Capabilities\n" +
-        " image \n" +
-        "\n"
-
-    if diff := cmp.Diff(expect, b.String()); diff != "" {
-        t.Errorf("unexpected output (-want +got):\n%s", diff)
-    }
-}
-
-func TestPushProgressMessage(t *testing.T) {
-    tests := []struct {
-        name    string
-        status  string
-        digest  string
-        wantMsg string
-    }{
-        {
-            name:    "uses status when provided",
-            status:  "uploading model",
-            digest:  "sha256:abc123456789def",
-            wantMsg: "uploading model",
-        },
-        {
-            name:    "falls back to digest when status empty",
-            status:  "",
-            digest:  "sha256:abc123456789def",
-            wantMsg: "pushing abc123456789...",
-        },
-        {
-            name:    "handles short digest gracefully",
-            status:  "",
-            digest:  "sha256:abc",
-            wantMsg: "pushing sha256:abc...",
-        },
-    }
-
-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            msg := tt.status
-            if msg == "" {
-                if len(tt.digest) >= 19 {
-                    msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
-                } else {
-                    msg = fmt.Sprintf("pushing %s...", tt.digest)
-                }
-            }
-            if msg != tt.wantMsg {
-                t.Errorf("got %q, want %q", msg, tt.wantMsg)
-            }
-        })
-    }
-}
-
 func TestRunOptions_Copy_Independence(t *testing.T) {
     // Test that modifications to original don't affect copy
     originalThink := &api.ThinkValue{Value: "original"}

docs/troubleshooting.md (new file, +3)

View File

@@ -0,0 +1,3 @@
+# Troubleshooting
+
+For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

View File

@@ -118,9 +118,6 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
             return
         }

-        // Set think to nil when being used with Anthropic API to connect to tools like claude code
-        c.Set("relax_thinking", true)
-
         var b bytes.Buffer
         if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
             c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))

View File

@@ -582,26 +582,3 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
         })
     }
 }
-
-func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
-    gin.SetMode(gin.TestMode)
-
-    var flagSet bool
-    router := gin.New()
-    router.Use(AnthropicMessagesMiddleware())
-    router.POST("/v1/messages", func(c *gin.Context) {
-        _, flagSet = c.Get("relax_thinking")
-        c.Status(http.StatusOK)
-    })
-
-    body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
-    req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
-    req.Header.Set("Content-Type", "application/json")
-
-    resp := httptest.NewRecorder()
-    router.ServeHTTP(resp, req)
-
-    if !flagSet {
-        t.Error("expected relax_thinking flag to be set in context")
-    }
-}

View File

@@ -73,7 +73,7 @@ _build_darwin() {
             MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
             MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
         fi
-        GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
+        GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
         GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
     done
 }
@@ -82,19 +82,19 @@ _sign_darwin() {
     status "Creating universal binary..."
     mkdir -p dist/darwin
     lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
-    lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
+    lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
     chmod +x dist/darwin/ollama
-    chmod +x dist/darwin/ollama-mlx
+    chmod +x dist/darwin/imagegen
     if [ -n "$APPLE_IDENTITY" ]; then
-        for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
+        for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
             codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
         done

         # create a temporary zip for notarization
         TEMP=$(mktemp -u).zip
         ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
-        xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
+        xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
         rm -f "$TEMP"
     fi
@@ -154,25 +154,23 @@ _build_macapp() {
     mkdir -p dist/Ollama.app/Contents/Resources
     if [ -d dist/darwin-amd64 ]; then
         lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
-        lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
+        lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
         for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
             lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
         done
         cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
         cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
-        # Copy MLX metallib (architecture-independent, just use arm64 version)
-        cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
     else
         cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
         cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
     fi
-    cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
+    cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
     chmod a+x dist/Ollama.app/Contents/Resources/ollama

     # Sign
     if [ -n "$APPLE_IDENTITY" ]; then
         codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
-        for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
+        for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
             codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
         done
         codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -180,11 +178,11 @@ _build_macapp() {
     rm -f dist/Ollama-darwin.zip
     ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
-    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
+    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz

     # Notarize and Staple
     if [ -n "$APPLE_IDENTITY" ]; then
-        $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
+        $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
         rm -f dist/Ollama-darwin.zip
         $(xcrun -f stapler) staple dist/Ollama.app
         ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
@@ -208,7 +206,7 @@ _build_macapp() {
     rm -f dist/rw*.dmg
     codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
-    $(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
+    $(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
     $(xcrun -f stapler) staple dist/Ollama.dmg
 else
     echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"

View File

@@ -48,12 +48,53 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
         .
 fi

+# Deduplicate CUDA libraries across mlx_* and cuda_* directories
+deduplicate_cuda_libs() {
+    local base_dir="$1"
+    echo "Deduplicating CUDA libraries in ${base_dir}..."
+
+    # Find all mlx_cuda_* directories
+    for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
+        [ -d "${mlx_dir}" ] || continue
+
+        # Extract CUDA version (e.g., v12, v13)
+        cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
+        cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
+
+        # Skip if corresponding cuda_* directory doesn't exist
+        [ -d "${cuda_dir}" ] || continue
+
+        echo "  Checking ${mlx_dir} against ${cuda_dir}..."
+
+        # Find all .so* files in mlx directory
+        find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
+            filename=$(basename "${mlx_file}")
+            cuda_file="${cuda_dir}/${filename}"
+
+            # Skip if file doesn't exist in cuda directory
+            [ -f "${cuda_file}" ] || continue
+
+            # Compare checksums
+            mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
+            cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
+
+            if [ "${mlx_sum}" = "${cuda_sum}" ]; then
+                echo "    Deduplicating ${filename}"
+                # Calculate relative path from mlx_dir to cuda_dir
+                rel_path="../cuda_${cuda_version}/${filename}"
+                rm -f "${mlx_file}"
+                ln -s "${rel_path}" "${mlx_file}"
+            fi
+        done
+    done
+}
+
 # Run deduplication for each platform output directory
 if echo $PLATFORM | grep "," > /dev/null ; then
-    $(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
-    $(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
+    deduplicate_cuda_libs "./dist/linux_amd64"
+    deduplicate_cuda_libs "./dist/linux_arm64"
 elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
-    $(dirname $0)/deduplicate_cuda_libs.sh "./dist"
+    deduplicate_cuda_libs "./dist"
 fi

 # buildx behavior changes for single vs. multiplatform
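The inlined function dedupes by hashing both copies with sha256sum and swapping identical mlx_cuda_* files for relative symlinks into the matching cuda_* directory. The same idea expressed as a hedged Go sketch (the example paths in main are assumptions; the build itself uses the shell version above):

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"os"
	"path/filepath"
)

// fileSHA256 returns the hex SHA-256 digest of a file's contents.
func fileSHA256(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()
	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", h.Sum(nil)), nil
}

// dedupe replaces mlxFile with a relative symlink into cudaDir when the
// two copies are byte-identical, mirroring the shell logic above.
func dedupe(mlxFile, cudaDir string) error {
	cudaFile := filepath.Join(cudaDir, filepath.Base(mlxFile))
	a, err := fileSHA256(mlxFile)
	if err != nil {
		return err
	}
	b, err := fileSHA256(cudaFile)
	if err != nil {
		return err // also covers "no counterpart in cuda_*"
	}
	if a != b {
		return nil // different contents: keep both copies
	}
	rel := filepath.Join("..", filepath.Base(cudaDir), filepath.Base(mlxFile))
	if err := os.Remove(mlxFile); err != nil {
		return err
	}
	return os.Symlink(rel, mlxFile)
}

func main() {
	// Hypothetical paths for illustration only.
	err := dedupe("dist/lib/ollama/mlx_cuda_v13/libcublas.so.13", "dist/lib/ollama/cuda_v13")
	if err != nil {
		fmt.Fprintln(os.Stderr, "dedupe:", err)
	}
}
```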
# buildx behavior changes for single vs. multiplatform # buildx behavior changes for single vs. multiplatform

View File

scripts/deduplicate_cuda_libs.sh (deleted; its logic now lives in build_linux.sh above)

@@ -1,60 +0,0 @@
-#!/bin/sh
-#
-# Deduplicate CUDA libraries across mlx_* and cuda_* directories
-# This script finds identical .so* files in mlx_cuda_* directories that exist
-# in corresponding cuda_* directories and replaces them with symlinks.
-#
-set -eu
-
-if [ $# -eq 0 ]; then
-    echo "ERROR: No directory specified" >&2
-    echo "Usage: $0 <base_directory>" >&2
-    exit 1
-fi
-
-base_dir="$1"
-
-if [ ! -d "${base_dir}" ]; then
-    echo "ERROR: Directory ${base_dir} does not exist" >&2
-    exit 1
-fi
-
-echo "Deduplicating CUDA libraries in ${base_dir}..."
-
-# Find all mlx_cuda_* directories
-for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
-    [ -d "${mlx_dir}" ] || continue
-
-    # Extract CUDA version (e.g., v12, v13)
-    cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
-    cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
-
-    # Skip if corresponding cuda_* directory doesn't exist
-    [ -d "${cuda_dir}" ] || continue
-
-    echo "  Checking ${mlx_dir} against ${cuda_dir}..."
-
-    # Find all .so* files in mlx directory
-    find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
-        filename=$(basename "${mlx_file}")
-        cuda_file="${cuda_dir}/${filename}"
-
-        # Skip if file doesn't exist in cuda directory
-        [ -f "${cuda_file}" ] || continue
-
-        # Compare checksums
-        mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
-        cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
-
-        if [ "${mlx_sum}" = "${cuda_sum}" ]; then
-            echo "    Deduplicating ${filename}"
-            # Calculate relative path from mlx_dir to cuda_dir
-            rel_path="../cuda_${cuda_version}/${filename}"
-            rm -f "${mlx_file}"
-            ln -s "${rel_path}" "${mlx_file}"
-        fi
-    done
-done
-
-echo "Deduplication complete"

View File

@@ -95,48 +95,11 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 }

 const (
-    // numDownloadParts is the default number of concurrent download parts for standard downloads
     numDownloadParts          = 16
-    // numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
-    // downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
-    numHFDownloadParts        = 4
     minDownloadPartSize int64 = 100 * format.MegaByte
     maxDownloadPartSize int64 = 1000 * format.MegaByte
 )

-// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
-// This includes:
-//   - huggingface.co (main domain)
-//   - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
-//   - hf.co (shortlink domain)
-//   - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
-func isHuggingFaceURL(u *url.URL) bool {
-    if u == nil {
-        return false
-    }
-    host := strings.ToLower(u.Hostname())
-    return host == "huggingface.co" ||
-        strings.HasSuffix(host, ".huggingface.co") ||
-        host == "hf.co" ||
-        strings.HasSuffix(host, ".hf.co")
-}
-
-// getNumDownloadParts returns the number of concurrent download parts to use
-// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
-// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
-// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
-func getNumDownloadParts(u *url.URL) int {
-    if isHuggingFaceURL(u) {
-        if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
-            if n, err := strconv.Atoi(v); err == nil && n > 0 {
-                return n
-            }
-        }
-        return numHFDownloadParts
-    }
-    return numDownloadParts
-}
-
 func (p *blobDownloadPart) Name() string {
     return strings.Join([]string{
         p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -308,11 +271,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
     }

     g, inner := errgroup.WithContext(ctx)
-    concurrency := getNumDownloadParts(directURL)
-    if concurrency != numDownloadParts {
-        slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
-    }
-    g.SetLimit(concurrency)
+    g.SetLimit(numDownloadParts)
     for i := range b.Parts {
         part := b.Parts[i]
         if part.Completed.Load() == part.Size {
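With the HuggingFace special-casing gone, part downloads return to a single fixed fan-out bounded by errgroup. A minimal sketch of that bounded-concurrency pattern (illustrative only, not the server's actual download loop):

```go
package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	const numDownloadParts = 16

	var g errgroup.Group
	// At most numDownloadParts goroutines run at once; the remaining
	// parts start as earlier ones finish.
	g.SetLimit(numDownloadParts)
	for i := 0; i < 64; i++ {
		part := i
		g.Go(func() error {
			fmt.Println("downloading part", part)
			return nil // a real part download would return its error here
		})
	}
	if err := g.Wait(); err != nil {
		fmt.Println("download failed:", err)
	}
}
```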

View File

@@ -1,194 +0,0 @@
-package server
-
-import (
-    "net/url"
-    "testing"
-
-    "github.com/stretchr/testify/assert"
-)
-
-func TestIsHuggingFaceURL(t *testing.T) {
-    tests := []struct {
-        name     string
-        url      string
-        expected bool
-    }{
-        {
-            name:     "nil url",
-            url:      "",
-            expected: false,
-        },
-        {
-            name:     "huggingface.co main domain",
-            url:      "https://huggingface.co/some/model",
-            expected: true,
-        },
-        {
-            name:     "cdn-lfs.huggingface.co subdomain",
-            url:      "https://cdn-lfs.huggingface.co/repos/abc/123",
-            expected: true,
-        },
-        {
-            name:     "cdn-lfs3.hf.co CDN domain",
-            url:      "https://cdn-lfs3.hf.co/repos/abc/123",
-            expected: true,
-        },
-        {
-            name:     "hf.co shortlink domain",
-            url:      "https://hf.co/model",
-            expected: true,
-        },
-        {
-            name:     "uppercase HuggingFace domain",
-            url:      "https://HUGGINGFACE.CO/model",
-            expected: true,
-        },
-        {
-            name:     "mixed case HF domain",
-            url:      "https://Cdn-Lfs.HF.Co/repos",
-            expected: true,
-        },
-        {
-            name:     "ollama registry",
-            url:      "https://registry.ollama.ai/v2/library/llama3",
-            expected: false,
-        },
-        {
-            name:     "github.com",
-            url:      "https://github.com/ollama/ollama",
-            expected: false,
-        },
-        {
-            name:     "fake huggingface domain",
-            url:      "https://nothuggingface.co/model",
-            expected: false,
-        },
-        {
-            name:     "fake hf domain",
-            url:      "https://nothf.co/model",
-            expected: false,
-        },
-        {
-            name:     "huggingface in path not host",
-            url:      "https://example.com/huggingface.co/model",
-            expected: false,
-        },
-    }
-
-    for _, tc := range tests {
-        t.Run(tc.name, func(t *testing.T) {
-            var u *url.URL
-            if tc.url != "" {
-                var err error
-                u, err = url.Parse(tc.url)
-                if err != nil {
-                    t.Fatalf("failed to parse URL: %v", err)
-                }
-            }
-            got := isHuggingFaceURL(u)
-            assert.Equal(t, tc.expected, got)
-        })
-    }
-}
-
-func TestGetNumDownloadParts(t *testing.T) {
-    tests := []struct {
-        name        string
-        url         string
-        envValue    string
-        expected    int
-        description string
-    }{
-        {
-            name:        "nil url returns default",
-            url:         "",
-            envValue:    "",
-            expected:    numDownloadParts,
-            description: "nil URL should return standard concurrency",
-        },
-        {
-            name:        "ollama registry returns default",
-            url:         "https://registry.ollama.ai/v2/library/llama3",
-            envValue:    "",
-            expected:    numDownloadParts,
-            description: "Ollama registry should use standard concurrency",
-        },
-        {
-            name:        "huggingface returns reduced default",
-            url:         "https://huggingface.co/model/repo",
-            envValue:    "",
-            expected:    numHFDownloadParts,
-            description: "HuggingFace should use reduced concurrency",
-        },
-        {
-            name:        "hf.co CDN returns reduced default",
-            url:         "https://cdn-lfs3.hf.co/repos/abc/123",
-            envValue:    "",
-            expected:    numHFDownloadParts,
-            description: "HuggingFace CDN should use reduced concurrency",
-        },
-        {
-            name:        "huggingface with env override",
-            url:         "https://huggingface.co/model/repo",
-            envValue:    "2",
-            expected:    2,
-            description: "OLLAMA_HF_CONCURRENCY should override default",
-        },
-        {
-            name:        "huggingface with higher env override",
-            url:         "https://huggingface.co/model/repo",
-            envValue:    "8",
-            expected:    8,
-            description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
-        },
-        {
-            name:        "huggingface with invalid env (non-numeric)",
-            url:         "https://huggingface.co/model/repo",
-            envValue:    "invalid",
-            expected:    numHFDownloadParts,
-            description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
-        },
-        {
-            name:        "huggingface with invalid env (zero)",
-            url:         "https://huggingface.co/model/repo",
-            envValue:    "0",
-            expected:    numHFDownloadParts,
-            description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
-        },
-        {
-            name:        "huggingface with invalid env (negative)",
-            url:         "https://huggingface.co/model/repo",
-            envValue:    "-1",
-            expected:    numHFDownloadParts,
-            description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
-        },
-        {
-            name:        "non-huggingface ignores env",
-            url:         "https://registry.ollama.ai/v2/library/llama3",
-            envValue:    "2",
-            expected:    numDownloadParts,
-            description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
-        },
-    }
-
-    for _, tc := range tests {
-        t.Run(tc.name, func(t *testing.T) {
-            // Set or clear the environment variable
-            if tc.envValue != "" {
-                t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
-            }
-
-            var u *url.URL
-            if tc.url != "" {
-                var err error
-                u, err = url.Parse(tc.url)
-                if err != nil {
-                    t.Fatalf("failed to parse URL: %v", err)
-                }
-            }
-
-            got := getNumDownloadParts(u)
-            assert.Equal(t, tc.expected, got, tc.description)
-        })
-    }
-}

View File

@@ -47,37 +47,13 @@ func (m *Manifest) Remove() error {
 }

 func (m *Manifest) RemoveLayers() error {
-    ms, err := Manifests(true)
-    if err != nil {
-        return err
-    }
-
-    // Build set of digests still in use by other manifests
-    inUse := make(map[string]struct{})
-    for _, other := range ms {
-        for _, layer := range append(other.Layers, other.Config) {
-            if layer.Digest != "" {
-                inUse[layer.Digest] = struct{}{}
-            }
-        }
-    }
-
-    // Remove layers not used by any other manifest
     for _, layer := range append(m.Layers, m.Config) {
-        if layer.Digest == "" {
-            continue
-        }
-        if _, used := inUse[layer.Digest]; used {
-            continue
-        }
-        blob, err := GetBlobsPath(layer.Digest)
-        if err != nil {
-            return err
-        }
-        if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
-            slog.Debug("layer does not exist", "digest", layer.Digest)
-        } else if err != nil {
-            return err
+        if layer.Digest != "" {
+            if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
+                slog.Debug("layer does not exist", "digest", layer.Digest)
+            } else if err != nil {
+                return err
+            }
         }
     }

View File

@@ -1124,15 +1124,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
         QuantizationLevel: m.Config.FileType,
     }

-    // For image generation models, populate details from imagegen package
-    if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
-        if info, err := imagegen.GetModelInfo(name.String()); err == nil {
-            modelDetails.Family = info.Architecture
-            modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
-            modelDetails.QuantizationLevel = info.Quantization
-        }
-    }
-
     if req.System != "" {
         m.System = req.System
     }
@@ -1215,10 +1206,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
         return resp, nil
     }

-    if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
-        return resp, nil
-    }
-
     kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
     if err != nil {
         return nil, err
@@ -2072,14 +2059,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
         }
     } else {
         if req.Think != nil && req.Think.Bool() {
-            // Set think to nil when being used with Anthropic API to connect to tools like claude code
-            if _, ok := c.Get("relax_thinking"); ok {
-                slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
-                req.Think = nil
-            } else {
-                c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
-                return
-            }
+            c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
+            return
         }
     }

View File

@@ -1,50 +1,24 @@
 # Experimental Features

 ## MLX Backend

 We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)

-Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors.
+Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:

-### Building ollama-mlx
-
-The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
-
-#### macOS (Apple Silicon and Intel)
-
-```bash
-# Build MLX backend libraries
+```
 cmake --preset MLX
 cmake --build --preset MLX --parallel
-cmake --install build --component MLX
-
-# Build ollama-mlx binary
-go build -tags mlx -o ollama-mlx .
+cmake --install --component MLX
+go build -tags mlx .
 ```

-#### Linux (CUDA)
-
-On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
-
-```bash
-# Build MLX backend libraries with CUDA support
-cmake --preset 'MLX CUDA 13'
-cmake --build --preset 'MLX CUDA 13' --parallel
-cmake --install build --component MLX
-
-# Build ollama-mlx binary
-CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
-CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
-go build -tags mlx -o ollama-mlx .
-```
-
-#### Using build scripts
-
-The build scripts automatically create the `ollama-mlx` binary:
-
-- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
-- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
+On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.

 ## Image Generation

-Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
+Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
+
+```
+go build -o imagegen ./x/imagegen/cmd/engine
+```

View File

@@ -41,7 +41,6 @@ var optionLabels = []string{
 var toolDisplayNames = map[string]string{
     "bash":       "Bash",
     "web_search": "Web Search",
-    "web_fetch":  "Web Fetch",
 }

 // ToolDisplayName returns the human-readable display name for a tool.
@@ -566,16 +565,6 @@ func formatToolDisplay(toolName string, args map[string]any) string {
         }
     }

-    // For web fetch, show URL and internet notice
-    if toolName == "web_fetch" {
-        if url, ok := args["url"].(string); ok {
-            sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
-            sb.WriteString(fmt.Sprintf("URL: %s\n", url))
-            sb.WriteString("Uses internet via ollama.com")
-            return sb.String()
-        }
-    }
-
     // Generic display
     sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
     if len(args) > 0 {
@@ -1028,16 +1017,6 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
         }
     }

-    if toolName == "web_fetch" {
-        if url, ok := args["url"].(string); ok {
-            // Truncate long URLs
-            if len(url) > 50 {
-                url = url[:47] + "..."
-            }
-            return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
-        }
-    }
-
     return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
 }

View File

@@ -9,7 +9,6 @@ import (
     "net/url"
     "os"
     "os/signal"
-    "slices"
     "strings"
     "syscall"
     "time"
@@ -25,14 +24,6 @@ import (
     "github.com/ollama/ollama/x/tools"
 )

-// MultilineState tracks the state of multiline input
-type MultilineState int
-
-const (
-    MultilineNone MultilineState = iota
-    MultilineSystem
-)
-
 // Tool output capping constants
 const (
     // localModelTokenLimit is the token limit for local models (smaller context).
@@ -139,7 +130,6 @@ type RunOptions struct {
     KeepAlive    *api.Duration
     Think        *api.ThinkValue
     HideThinking bool
-    Verbose      bool

     // Agent fields (managed externally for session persistence)
     Tools    *tools.Registry
@@ -188,7 +178,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
     var thinkTagClosed bool = false
     var pendingToolCalls []api.ToolCall
     var consecutiveErrors int // Track consecutive 500 errors for retry limit
-    var latest api.ChatResponse

     role := "assistant"
     messages := opts.Messages
@@ -198,7 +187,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
             p.StopAndClear()
         }

-        latest = response
         role = response.Message.Role

         if response.Message.Thinking != "" && !opts.HideThinking {
             if !thinkTagOpened {
@@ -495,10 +483,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
         fmt.Println()
     }

-    if opts.Verbose {
-        latest.Summary()
-    }
-
     return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
 }
@@ -650,8 +634,7 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
 // GenerateInteractive runs an interactive agent session.
 // This is called from cmd.go when --experimental flag is set.
 // If yoloMode is true, all tool approvals are skipped.
-// If enableWebsearch is true, the web search tool is registered.
-func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
+func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
     scanner, err := readline.New(readline.Prompt{
         Prompt:    ">>> ",
         AltPrompt: "... ",
@@ -677,12 +660,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
     if supportsTools {
         toolRegistry = tools.DefaultRegistry()

-        // Register web search and web fetch tools if enabled via flag
-        if enableWebsearch {
-            toolRegistry.RegisterWebSearch()
-            toolRegistry.RegisterWebFetch()
-        }
-
         if toolRegistry.Has("bash") {
             fmt.Fprintln(os.Stderr)
             fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
@@ -690,11 +667,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
             fmt.Fprintln(os.Stderr)
         }

-        if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
-            fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
-            fmt.Fprintln(os.Stderr)
-        }
-
         if yoloMode {
             fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
         }
@@ -705,9 +677,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
     var messages []api.Message
     var sb strings.Builder
-    var format string
-    var system string
-    var multiline MultilineState = MultilineNone

     for {
         line, err := scanner.Readline()
@@ -719,39 +688,13 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
             if line == "" {
                 fmt.Println("\nUse Ctrl + d or /bye to exit.")
             }
-            scanner.Prompt.UseAlt = false
             sb.Reset()
-            multiline = MultilineNone
             continue
         case err != nil:
             return err
         }

         switch {
-        case multiline != MultilineNone:
-            // check if there's a multiline terminating string
-            before, ok := strings.CutSuffix(line, `"""`)
-            sb.WriteString(before)
-            if !ok {
-                fmt.Fprintln(&sb)
-                continue
-            }
-
-            switch multiline {
-            case MultilineSystem:
-                system = sb.String()
-                newMessage := api.Message{Role: "system", Content: system}
-                if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
-                    messages[len(messages)-1] = newMessage
-                } else {
-                    messages = append(messages, newMessage)
-                }
-                fmt.Println("Set system message.")
-                sb.Reset()
-            }
-
-            multiline = MultilineNone
-            scanner.Prompt.UseAlt = false
         case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
             return nil
         case strings.HasPrefix(line, "/clear"):
@@ -764,10 +707,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
             continue
         case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
            fmt.Fprintln(os.Stderr, "Available Commands:")
-            fmt.Fprintln(os.Stderr, " /set Set session variables")
-            fmt.Fprintln(os.Stderr, " /show Show model information")
-            fmt.Fprintln(os.Stderr, " /load Load a different model")
-            fmt.Fprintln(os.Stderr, " /save Save session as a model")
             fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
             fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
             fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -777,303 +716,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
             fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
             fmt.Fprintln(os.Stderr, "")
             continue
-        case strings.HasPrefix(line, "/set"):
-            args := strings.Fields(line)
-            if len(args) > 1 {
-                switch args[1] {
-                case "history":
-                    scanner.HistoryEnable()
-                case "nohistory":
-                    scanner.HistoryDisable()
-                case "wordwrap":
-                    wordWrap = true
-                    fmt.Println("Set 'wordwrap' mode.")
-                case "nowordwrap":
-                    wordWrap = false
-                    fmt.Println("Set 'nowordwrap' mode.")
-                case "verbose":
-                    if err := cmd.Flags().Set("verbose", "true"); err != nil {
-                        return err
-                    }
-                    fmt.Println("Set 'verbose' mode.")
-                case "quiet":
-                    if err := cmd.Flags().Set("verbose", "false"); err != nil {
-                        return err
-                    }
-                    fmt.Println("Set 'quiet' mode.")
-                case "think":
-                    thinkValue := api.ThinkValue{Value: true}
-                    var maybeLevel string
-                    if len(args) > 2 {
-                        maybeLevel = args[2]
-                    }
-                    if maybeLevel != "" {
-                        thinkValue.Value = maybeLevel
-                    }
-                    think = &thinkValue
-                    // Check if model supports thinking
-                    if client, err := api.ClientFromEnvironment(); err == nil {
-                        if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
-                            if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
-                                fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
-                            }
-                        }
-                    }
-                    if maybeLevel != "" {
-                        fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
-                    } else {
-                        fmt.Println("Set 'think' mode.")
-                    }
-                case "nothink":
-                    think = &api.ThinkValue{Value: false}
-                    // Check if model supports thinking
-                    if client, err := api.ClientFromEnvironment(); err == nil {
-                        if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
-                            if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
-                                fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
-                            }
-                        }
-                    }
-                    fmt.Println("Set 'nothink' mode.")
-                case "format":
-                    if len(args) < 3 || args[2] != "json" {
-                        fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
-                    } else {
-                        format = args[2]
-                        fmt.Printf("Set format to '%s' mode.\n", args[2])
-                    }
-                case "noformat":
-                    format = ""
-                    fmt.Println("Disabled format.")
-                case "parameter":
-                    if len(args) < 4 {
-                        fmt.Println("Usage: /set parameter <name> <value>")
-                        continue
-                    }
-                    params := args[3:]
-                    fp, err := api.FormatParams(map[string][]string{args[2]: params})
-                    if err != nil {
-                        fmt.Printf("Couldn't set parameter: %q\n", err)
-                        continue
-                    }
-                    fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
-                    options[args[2]] = fp[args[2]]
-                case "system":
-                    if len(args) < 3 {
-                        fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
-                        continue
-                    }
-                    multiline = MultilineSystem
-                    line := strings.Join(args[2:], " ")
-                    line, ok := strings.CutPrefix(line, `"""`)
-                    if !ok {
-                        multiline = MultilineNone
-                    } else {
-                        // only cut suffix if the line is multiline
-                        line, ok = strings.CutSuffix(line, `"""`)
-                        if ok {
-                            multiline = MultilineNone
-                        }
-                    }
-                    sb.WriteString(line)
-                    if multiline != MultilineNone {
-                        scanner.Prompt.UseAlt = true
-                        continue
-                    }
-                    system = sb.String()
-                    newMessage := api.Message{Role: "system", Content: sb.String()}
-                    // Check if the slice is not empty and the last message is from 'system'
-                    if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
-                        // Replace the last message
-                        messages[len(messages)-1] = newMessage
-                    } else {
-                        messages = append(messages, newMessage)
-                    }
-                    fmt.Println("Set system message.")
-                    sb.Reset()
-                    continue
-                default:
-                    fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
-                }
-            } else {
-                fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
-            }
-            continue
-        case strings.HasPrefix(line, "/show"):
-            args := strings.Fields(line)
-            if len(args) > 1 {
-                client, err := api.ClientFromEnvironment()
-                if err != nil {
-                    fmt.Println("error: couldn't connect to ollama server")
-                    continue
-                }
-                req := &api.ShowRequest{
-                    Name:    modelName,
-                    Options: options,
-                }
-                resp, err := client.Show(cmd.Context(), req)
-                if err != nil {
-                    fmt.Println("error: couldn't get model")
-                    continue
-                }
-                switch args[1] {
-                case "info":
-                    fmt.Fprintf(os.Stderr, " Model\n")
-                    fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
-                    if resp.Details.Family != "" {
-                        fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
-                    }
-                    if resp.Details.ParameterSize != "" {
-                        fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
-                    }
-                    if resp.Details.QuantizationLevel != "" {
-                        fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
-                    }
-                    if len(resp.Capabilities) > 0 {
-                        caps := make([]string, len(resp.Capabilities))
-                        for i, c := range resp.Capabilities {
-                            caps[i] = string(c)
-                        }
-                        fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
-                    }
-                    fmt.Fprintln(os.Stderr)
-                case "license":
-                    if resp.License == "" {
-                        fmt.Println("No license was specified for this model.")
-                    } else {
-                        fmt.Println(resp.License)
-                    }
-                case "modelfile":
-                    fmt.Println(resp.Modelfile)
-                case "parameters":
-                    fmt.Println("Model defined parameters:")
-                    if resp.Parameters == "" {
-                        fmt.Println(" No additional parameters were specified.")
-                    } else {
-                        for _, l := range strings.Split(resp.Parameters, "\n") {
-                            fmt.Printf(" %s\n", l)
-                        }
-                    }
-                    if len(options) > 0 {
-                        fmt.Println("\nUser defined parameters:")
-                        for k, v := range options {
-                            fmt.Printf(" %-30s %v\n", k, v)
-                        }
-                    }
-                case "system":
-                    switch {
-                    case system != "":
-                        fmt.Println(system + "\n")
-                    case resp.System != "":
-                        fmt.Println(resp.System + "\n")
-                    default:
-                        fmt.Println("No system message was specified for this model.")
-                    }
-                case "template":
-                    if resp.Template != "" {
-                        fmt.Println(resp.Template)
-                    } else {
-                        fmt.Println("No prompt template was specified for this model.")
-                    }
-                default:
-                    fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
-                }
-            } else {
-                fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
-            }
-            continue
-        case strings.HasPrefix(line, "/load"):
-            args := strings.Fields(line)
-            if len(args) != 2 {
-                fmt.Println("Usage: /load <modelname>")
-                continue
-            }
-            newModelName := args[1]
-            fmt.Printf("Loading model '%s'\n", newModelName)
-
-            // Create progress spinner
-            p := progress.NewProgress(os.Stderr)
-            spinner := progress.NewSpinner("")
-            p.Add("", spinner)
-
-            // Get client
-            client, err := api.ClientFromEnvironment()
-            if err != nil {
-                p.StopAndClear()
-                fmt.Println("error: couldn't connect to ollama server")
-                continue
-            }
-
-            // Check if model exists and get its info
-            info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
-            if err != nil {
-                p.StopAndClear()
-                if strings.Contains(err.Error(), "not found") {
-                    fmt.Printf("Couldn't find model '%s'\n", newModelName)
-                } else {
-                    fmt.Printf("error: %v\n", err)
-                }
-                continue
-            }
-
-            // For cloud models, no need to preload
-            if info.RemoteHost == "" {
-                // Preload the model by sending an empty generate request
-                req := &api.GenerateRequest{
-                    Model: newModelName,
-                    Think: think,
-                }
-                err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
-                    return nil
-                })
-                if err != nil {
-                    p.StopAndClear()
-                    if strings.Contains(err.Error(), "not found") {
-                        fmt.Printf("Couldn't find model '%s'\n", newModelName)
-                    } else if strings.Contains(err.Error(), "does not support thinking") {
-                        fmt.Printf("error: %v\n", err)
-                    } else {
-                        fmt.Printf("error loading model: %v\n", err)
-                    }
-                    continue
-                }
-            }
-
-            p.StopAndClear()
-            modelName = newModelName
-            messages = []api.Message{}
-            approval.Reset()
-            continue
-        case strings.HasPrefix(line, "/save"):
-            args := strings.Fields(line)
-            if len(args) != 2 {
-                fmt.Println("Usage: /save <modelname>")
-                continue
-            }
-            client, err := api.ClientFromEnvironment()
-            if err != nil {
-                fmt.Println("error: couldn't connect to ollama server")
-                continue
-            }
-            req := &api.CreateRequest{
-                Model:      args[1],
-                From:       modelName,
-                Parameters: options,
-                Messages:   messages,
-            }
-            fn := func(resp api.ProgressResponse) error { return nil }
-            err = client.Create(cmd.Context(), req, fn)
-            if err != nil {
-                fmt.Printf("error: %v\n", err)
-                continue
-            }
-            fmt.Printf("Created new model '%s'\n", args[1])
-            continue
         case strings.HasPrefix(line, "/"):
             fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
             continue
@@ -1081,16 +723,14 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
             sb.WriteString(line)
         }

-        if sb.Len() > 0 && multiline == MultilineNone {
+        if sb.Len() > 0 {
             newMessage := api.Message{Role: "user", Content: sb.String()}
             messages = append(messages, newMessage)

-            verbose, _ := cmd.Flags().GetBool("verbose")
-
             opts := RunOptions{
                 Model:        modelName,
                 Messages:     messages,
                 WordWrap:     wordWrap,
-                Format:       format,
                 Options:      options,
                 Think:        think,
                 HideThinking: hideThinking,
@@ -1098,7 +738,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
                 Tools:    toolRegistry,
                 Approval: approval,
                 YoloMode: yoloMode,
-                Verbose:  verbose,
             }

             assistant, err := Chat(cmd.Context(), opts)

View File

@@ -234,17 +234,3 @@ ollama create z-image
 3. Copy config files (*.json) as config layers
 4. Write manifest
 ```
-
-## FP8 Quantization
-
-Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
-
-### Usage
-
-```bash
-cd ./weights/Z-Image-Turbo
-ollama create z-image-fp8 --quantize fp8
-```
-
-This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.

View File

@@ -1,8 +1,10 @@
package api package api
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"os"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -99,10 +101,10 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
-var imageBase64 string
+var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
-imageBase64 = extractBase64(resp.Content)
+imagePath = extractPath(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
@@ -116,14 +118,14 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
return
}
-c.SSEvent("done", buildResponse(imageBase64, format))
+c.SSEvent("done", buildResponse(imagePath, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
-var imageBase64 string
+var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
-imageBase64 = extractBase64(resp.Content)
+imagePath = extractPath(resp.Content)
}
})
if err != nil {
@@ -131,7 +133,7 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
return
}
-c.JSON(http.StatusOK, buildResponse(imageBase64, format))
+c.JSON(http.StatusOK, buildResponse(imagePath, format))
}
func parseSize(size string) (int32, int32) {
@@ -150,9 +152,9 @@ func parseSize(size string) (int32, int32) {
return int32(w), int32(h)
}
-func extractBase64(content string) string {
-if strings.HasPrefix(content, "IMAGE_BASE64:") {
-return content[13:]
+func extractPath(content string) string {
+if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
+return strings.TrimSpace(content[idx+16:])
}
return ""
}
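A quick sketch of the new parsing contract; the runner output line here is an assumed example:

```go
package api

import "fmt"

// ExampleExtractPath demonstrates extractPath on a hypothetical
// final completion payload from the image runner.
func ExampleExtractPath() {
	content := "step 20/20\nImage saved to: /tmp/ollama/images/out.png\n"
	fmt.Println(extractPath(content))
	// Output: /tmp/ollama/images/out.png
}
```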
@@ -163,21 +165,23 @@ func parseProgress(content string) ImageProgressEvent {
return ImageProgressEvent{Step: step, Total: total}
}
-func buildResponse(imageBase64, format string) ImageGenerationResponse {
+func buildResponse(imagePath, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
-if imageBase64 == "" {
+if imagePath == "" {
return resp
}
if format == "url" {
-// URL format not supported when using base64 transfer
-resp.Data[0].B64JSON = imageBase64
+resp.Data[0].URL = "file://" + imagePath
} else {
-resp.Data[0].B64JSON = imageBase64
+data, err := os.ReadFile(imagePath)
+if err == nil {
+resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
+}
}
return resp


@@ -1,197 +0,0 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}
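And the CFG-mode counterpart: only raw pos/neg predictions are cached, and the guidance mix is recomputed every step, so cached error is never amplified by the guidance scale. The forward and combine callbacks are hypothetical stand-ins:

```go
//go:build mlx

package cache

import "github.com/ollama/ollama/x/imagegen/mlx"

// cfgStep sketches the CFG contract described in the type comment above.
func cfgStep(tc *TeaCache, step int, t float32,
	forward func(neg bool) *mlx.Array,
	combine func(pos, neg *mlx.Array) *mlx.Array,
) *mlx.Array {
	if tc.ShouldCompute(step, t) {
		pos, neg := forward(false), forward(true)
		tc.UpdateCFGCache(pos, neg, t)
		return combine(pos, neg)
	}
	pos, neg := tc.GetCFGCached()
	return combine(pos, neg) // guidance applied fresh to cached predictions
}
```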


@@ -44,64 +44,62 @@ func DefaultOptions() ImageGenOptions {
}
}
-// ModelInfo contains metadata about an image generation model.
-type ModelInfo struct {
-Architecture string
-ParameterCount int64
-Quantization string
-}
-// GetModelInfo returns metadata about an image generation model.
-func GetModelInfo(modelName string) (*ModelInfo, error) {
+// Show displays information about an image generation model.
+func Show(modelName string, w io.Writer) error {
manifest, err := LoadManifest(modelName)
if err != nil {
-return nil, fmt.Errorf("failed to load manifest: %w", err)
+return fmt.Errorf("failed to load manifest: %w", err)
}
-info := &ModelInfo{}
+// Count total size
+var totalSize int64
+for _, layer := range manifest.Manifest.Layers {
+if layer.MediaType == "application/vnd.ollama.image.tensor" {
+totalSize += layer.Size
+}
+}
-// Read model_index.json for architecture, parameter count, and quantization
+// Read model_index.json for architecture
+var architecture string
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
-ParameterCount int64 `json:"parameter_count"`
-Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
-info.Architecture = index.Architecture
-info.ParameterCount = index.ParameterCount
-info.Quantization = index.Quantization
+architecture = index.Architecture
}
}
-// Fallback: detect quantization from tensor names if not in config
-if info.Quantization == "" {
-for _, layer := range manifest.Manifest.Layers {
-if strings.HasSuffix(layer.Name, ".weight_scale") {
-info.Quantization = "FP8"
-break
-}
-}
-if info.Quantization == "" {
-info.Quantization = "BF16"
-}
-}
-// Fallback: estimate parameter count if not in config
-if info.ParameterCount == 0 {
-var totalSize int64
-for _, layer := range manifest.Manifest.Layers {
-if layer.MediaType == "application/vnd.ollama.image.tensor" {
-if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
-totalSize += layer.Size
-}
-}
-}
-// Assume BF16 (2 bytes/param) as rough estimate
-info.ParameterCount = totalSize / 2
-}
+// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
+paramCount := totalSize / 2
+paramStr := formatParamCount(paramCount)
+// Print Model info
+fmt.Fprintln(w, " Model")
+if architecture != "" {
+fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
}
+fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
+fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
+fmt.Fprintln(w)
-return info, nil
+// Print Capabilities
+fmt.Fprintln(w, " Capabilities")
+fmt.Fprintf(w, " %s\n", "image")
+fmt.Fprintln(w)
+return nil
+}
+// formatParamCount formats parameter count as human-readable string.
+func formatParamCount(count int64) string {
+if count >= 1_000_000_000 {
+return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
+}
+if count >= 1_000_000 {
+return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
+}
+return fmt.Sprintf("%d", count)
}
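The estimate is coarse but predictable; for example, a 31GB sum of tensor layers reports as follows (illustrative figures):

```go
package imagegen

import "fmt"

// exampleEstimate shows the Show() math on made-up numbers.
func exampleEstimate() {
	totalSize := int64(31_000_000_000)        // summed tensor-layer bytes
	paramCount := totalSize / 2               // BF16: 2 bytes per parameter
	fmt.Println(formatParamCount(paramCount)) // prints "15.5B"
}
```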
// RegisterFlags adds image generation flags to the given command. // RegisterFlags adds image generation flags to the given command.
@@ -123,6 +121,11 @@ func RegisterFlags(cmd *cobra.Command) {
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
+// Verify it's a valid image gen model
+if ResolveModelName(name) == "" {
+return fmt.Errorf("unknown image generation model: %s", name)
+}
// Get options from flags (with env var defaults)
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
@@ -180,7 +183,8 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
p.Add("", spinner)
var stepBar *progress.StepBar
-var imageBase64 string
+var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -199,9 +203,11 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return nil
}
-// Handle final response with base64 image data
-if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
-imageBase64 = content[13:]
+// Handle final response with image path
+if resp.Done && strings.Contains(content, "Image saved to:") {
+if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
+imagePath = strings.TrimSpace(content[idx+16:])
+}
}
return nil
@@ -212,27 +218,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
-if imageBase64 != "" {
-// Decode base64 and save to CWD
-imageData, err := base64.StdEncoding.DecodeString(imageBase64)
-if err != nil {
-return fmt.Errorf("failed to decode image: %w", err)
-}
-// Create filename from prompt
-safeName := sanitizeFilename(prompt)
-if len(safeName) > 50 {
-safeName = safeName[:50]
-}
-timestamp := time.Now().Format("20060102-150405")
-filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
-if err := os.WriteFile(filename, imageData, 0o644); err != nil {
-return fmt.Errorf("failed to save image: %w", err)
-}
-displayImageInTerminal(filename)
-fmt.Printf("Image saved to: %s\n", filename)
+if imagePath != "" {
+displayImageInTerminal(imagePath)
+fmt.Printf("Image saved to: %s\n", imagePath)
}
return nil
@@ -318,7 +306,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
p.Add("", spinner)
var stepBar *progress.StepBar
-var imageBase64 string
+var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -338,9 +326,11 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return nil
}
-// Handle final response with base64 image data
-if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
-imageBase64 = content[13:]
+// Handle final response with image path
+if resp.Done && strings.Contains(content, "Image saved to:") {
+if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
+imagePath = strings.TrimSpace(content[idx+16:])
+}
}
return nil
@@ -352,30 +342,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
continue
}
-// Save image to current directory with descriptive name
-if imageBase64 != "" {
-// Decode base64 image data
-imageData, err := base64.StdEncoding.DecodeString(imageBase64)
-if err != nil {
-fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
-continue
-}
+// Copy image to current directory with descriptive name
+if imagePath != "" {
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
-filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
-if err := os.WriteFile(filename, imageData, 0o644); err != nil {
-fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
-continue
+newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
+// Copy file to CWD
+if err := copyFile(imagePath, newName); err != nil {
+fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
+displayImageInTerminal(imagePath)
+fmt.Printf("Image saved to: %s\n", imagePath)
+} else {
+displayImageInTerminal(newName)
+fmt.Printf("Image saved to: %s\n", newName)
}
-displayImageInTerminal(filename)
-fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
@@ -396,6 +381,24 @@ func sanitizeFilename(s string) string {
return result.String()
}
// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
@@ -506,7 +509,10 @@ func displayImageInTerminal(imagePath string) bool {
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
-end := min(i+chunkSize, len(encoded))
+end := i + chunkSize
+if end > len(encoded) {
+end = len(encoded)
+}
chunk := encoded[i:end]
if i == 0 {


@@ -17,7 +17,10 @@ import (
"encoding/json"
"fmt"
"io"
+"strings"
+"github.com/ollama/ollama/api"
+"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
@@ -28,15 +31,41 @@ import (
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
-// This creates blobs and manifest directly on disk, bypassing the HTTP API.
-// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
-//
-// TODO (jmorganca): Replace with API-based creation when promoted to production.
-func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
+func CreateModel(modelName, modelDir string, p *progress.Progress) error {
+return CreateModelFromModelfile(modelName, modelDir, nil, p)
+}
+// CreateModelFromModelfile imports a tensor-based model using Modelfile commands.
+// Extracts LICENSE, REQUIRES, and PARAMETER commands from the Modelfile.
+func CreateModelFromModelfile(modelName, modelDir string, commands []parser.Command, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
// Extract metadata from Modelfile commands
var licenses []string
var requires string
params := make(map[string]any)
for _, c := range commands {
switch c.Name {
case "license":
licenses = append(licenses, c.Args)
case "requires":
requires = c.Args
case "model":
// skip - already handled by caller
default:
// Treat as parameter (steps, width, height, seed, etc.)
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
if err == nil {
for k, v := range ps {
params[k] = v
}
}
}
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
@@ -47,8 +76,6 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
if err != nil {
return imagegen.LayerInfo{}, err
}
-layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
@@ -57,79 +84,17 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
}, nil
}
-// Create tensor layer callback for individual tensors
-// name is path-style: "component/tensor_name"
-// When quantize is true, returns multiple layers (weight + scales)
-createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
+// Create tensor layer callback
+createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
-return nil, err
+return imagegen.LayerInfo{}, err
}
-return []imagegen.LayerInfo{
-{
-Digest: layer.Digest,
-Size: layer.Size,
-MediaType: layer.MediaType,
-Name: name,
-},
+return imagegen.LayerInfo{
+Digest: layer.Digest,
+Size: layer.Size,
+MediaType: layer.MediaType,
+Name: name,
}, nil
}
@@ -140,24 +105,27 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
return fmt.Errorf("invalid model name: %s", modelName)
}
-// Create a proper config blob with version requirement
+// Use Modelfile REQUIRES if specified, otherwise use minimum
+if requires == "" {
+requires = MinOllamaVersion
+}
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
-Requires: MinOllamaVersion,
+Requires: requires,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
-// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
-// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
+// Convert to server.Layer
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
@@ -168,10 +136,31 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
}
}
// Add license layers
for _, license := range licenses {
layer, err := server.NewLayer(strings.NewReader(license), "application/vnd.ollama.image.license")
if err != nil {
return fmt.Errorf("failed to create license layer: %w", err)
}
serverLayers = append(serverLayers, layer)
}
// Add parameters layer
if len(params) > 0 {
paramsJSON, err := json.Marshal(params)
if err != nil {
return fmt.Errorf("failed to marshal parameters: %w", err)
}
layer, err := server.NewLayer(bytes.NewReader(paramsJSON), "application/vnd.ollama.image.params")
if err != nil {
return fmt.Errorf("failed to create params layer: %w", err)
}
serverLayers = append(serverLayers, layer)
}
return server.WriteManifest(name, configLayer, serverLayers)
}
-// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
@@ -179,7 +168,7 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
p.Add("imagegen", spinner)
}
-err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
+err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
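End to end, the intended call path looks roughly like this. The Modelfile contents and paths are hypothetical, and parser.ParseFile with its Commands field is assumed from the parser package:

```go
package client

import (
	"os"

	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/progress"
)

// importZImage sketches importing a model whose Modelfile reads:
//
//	FROM ./weights/Z-Image-Turbo
//	LICENSE Apache-2.0
//	REQUIRES 0.15.0
//	PARAMETER num_predict 12
func importZImage(p *progress.Progress) error {
	f, err := os.Open("Modelfile")
	if err != nil {
		return err
	}
	defer f.Close()
	mf, err := parser.ParseFile(f) // assumed parser entry point
	if err != nil {
		return err
	}
	// FROM surfaces as the "model" command and is handled by the caller;
	// the remaining commands become license/requires/parameter layers.
	return CreateModelFromModelfile("z-image", "./weights/Z-Image-Turbo", mf.Commands, p)
}
```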


@@ -0,0 +1,35 @@
package client
import (
"testing"
"github.com/ollama/ollama/parser"
)
func TestCreateModelFromModelfileExtractsMetadata(t *testing.T) {
// Build a command list like the one CreateModelFromModelfile receives from the parser
commands := []parser.Command{
{Name: "model", Args: "./weights/test"},
{Name: "license", Args: "Apache-2.0"},
{Name: "requires", Args: "0.15.0"},
{Name: "num_predict", Args: "12"},
{Name: "seed", Args: "42"},
}
// We can't easily test the full function without a real model dir,
// but we can verify the commands are valid parser.Command types
for _, c := range commands {
if c.Name == "" {
t.Error("Command name should not be empty")
}
}
}
func TestMinOllamaVersion(t *testing.T) {
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion[0] < '0' || MinOllamaVersion[0] > '9' {
t.Errorf("MinOllamaVersion should start with a number, got %q", MinOllamaVersion)
}
}


@@ -1,120 +0,0 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0755)
return tmpDir
}


@@ -1,18 +0,0 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}


@@ -67,9 +67,6 @@ func main() {
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
-teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
-teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
-fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
@@ -102,17 +99,13 @@ func main() {
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
-NegativePrompt: *negativePrompt,
-CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
-TeaCache: *teaCache,
-TeaCacheThreshold: float32(*teaCacheThreshold),
-FusedQKV: *fusedQKV,
+LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)


@@ -40,12 +40,10 @@ type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo)
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
-// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
-func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
+func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
-var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
@@ -76,11 +74,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
}
tensorNames := extractor.ListTensors()
-quantizeMsg := ""
-if quantize == "fp8" && component != "vae" {
-quantizeMsg = ", quantizing to fp8"
-}
-fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
+fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
@@ -89,30 +83,16 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
-// Count parameters from original tensor shape
-if len(td.Shape) > 0 {
-numElements := int64(1)
-for _, dim := range td.Shape {
-numElements *= int64(dim)
-}
-totalParams += numElements
-}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
-// Determine if this tensor should be quantized
-doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
-// createTensorLayer returns multiple layers if quantizing (weight + scales)
-newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
+layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
-layers = append(layers, newLayers...)
+layers = append(layers, layer)
}
extractor.Close()
@@ -142,7 +122,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
var r io.Reader
-// For model_index.json, normalize to Ollama format and add metadata
+// For model_index.json, normalize to Ollama format
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
@@ -161,16 +141,6 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
}
delete(cfg, "_diffusers_version")
-// Add parameter count (counted from tensor shapes during import)
-cfg["parameter_count"] = totalParams
-// Add quantization info
-if quantize == "fp8" {
-cfg["quantization"] = "FP8"
-} else {
-cfg["quantization"] = "BF16"
-}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
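For reference, the "minimal safetensors" container named above is just an 8-byte little-endian header length, a one-entry JSON header, and the raw tensor bytes. A sketch of that layout; the real writer is td.SafetensorsReader, and the exact header fields here are assumptions based on the safetensors spec:

```go
package imagegen

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
)

// minimalSafetensors sketches a single-tensor safetensors blob with the
// tensor stored under the name "data".
func minimalSafetensors(dtype string, shape []int32, raw []byte) ([]byte, error) {
	type entry struct {
		Dtype       string  `json:"dtype"`
		Shape       []int32 `json:"shape"`
		DataOffsets [2]int  `json:"data_offsets"`
	}
	hdr, err := json.Marshal(map[string]entry{
		"data": {Dtype: dtype, Shape: shape, DataOffsets: [2]int{0, len(raw)}},
	})
	if err != nil {
		return nil, err
	}
	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint64(len(hdr))) // 8-byte header size
	buf.Write(hdr)
	buf.Write(raw)
	return buf.Bytes(), nil
}
```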


@@ -60,12 +60,9 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
}
// Transform to [H, W, C] for image conversion
-// Free intermediate arrays to avoid memory leak
-squeezed := mlx.Squeeze(arr, 0)
-transposed := mlx.Transpose(squeezed, 1, 2, 0)
-squeezed.Free()
-img := mlx.Contiguous(transposed)
-transposed.Free()
+img := mlx.Squeeze(arr, 0)
+img = mlx.Transpose(img, 1, 2, 0)
+img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()


@@ -607,11 +607,6 @@ func (a *Array) Valid() bool {
return a != nil && a.c.ctx != nil
}
-// Kept returns true if the array is marked to survive Eval() cleanup.
-func (a *Array) Kept() bool {
-return a != nil && a.kept
-}
func int32ToCInt(s []int32) *C.int {
if len(s) == 0 {
return nil
@@ -1485,44 +1480,6 @@ func (a *Array) ItemInt32() int32 {
return int32(val)
}
// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
cleanup()
nbytes := a.Nbytes()
if nbytes == 0 {
return nil
}
// Get raw pointer based on dtype
var ptr unsafe.Pointer
switch a.Dtype() {
case DtypeFloat32:
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
case DtypeInt32:
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
case DtypeUint32:
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
case DtypeUint8:
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
default:
// For other types (bf16, f16, etc), convert to float32
arr := AsType(a, DtypeFloat32)
arr.Eval()
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
nbytes = arr.Nbytes()
}
if ptr == nil {
return nil
}
data := make([]byte, nbytes)
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
return data
}
// ============ Utility ============
// String returns a string representation
@@ -1701,34 +1658,6 @@ func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_string_free(s.metadata)
}
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Create the map
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
// Add each array to the map
for name, arr := range arrays {
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
C.free(unsafe.Pointer(cName))
}
// Create empty metadata (optional)
cMeta := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMeta)
// Save
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
@@ -2057,8 +1986,7 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
-// mode: "affine" (default), "mxfp4", or "mxfp8"
-// Note: mxfp8 mode returns nil biases (only weights and scales)
+// mode: "affine" (default) or "mxfp4"
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
@@ -2067,21 +1995,14 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
-// Result is a vector of arrays: [weights, scales, biases?]
-// mxfp8 mode returns only 2 elements (no biases)
-vecSize := int(C.mlx_vector_array_size(res))
+// Result is a vector of 3 arrays: [weights, scales, biases]
var w0, w1, w2 C.mlx_array
C.mlx_vector_array_get(&w0, res, 0)
C.mlx_vector_array_get(&w1, res, 1)
-if vecSize >= 3 {
-C.mlx_vector_array_get(&w2, res, 2)
-}
+C.mlx_vector_array_get(&w2, res, 2)
C.mlx_vector_array_free(res)
-if vecSize >= 3 {
-return newArray(w0), newArray(w1), newArray(w2)
-}
-return newArray(w0), newArray(w1), nil
+return newArray(w0), newArray(w1), newArray(w2)
}
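Typical usage of this wrapper, with a dequantize round trip for verification. This is a sketch only: the Dequantize signature is assumed to mirror Quantize, so treat it as illustrative rather than the package's actual API:

```go
package mlx

// roundTrip sketches 8-bit affine quantization with 32-element groups,
// the same settings the importer used for FP8 weights.
func roundTrip(w *Array) *Array {
	qw, scales, biases := Quantize(w, 32, 8, "affine")
	Eval(qw, scales, biases)
	return Dequantize(qw, scales, biases, 32, 8, "affine") // assumed signature
}
```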
// Dequantize reconstructs weights from quantized form.


@@ -222,14 +222,6 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
@@ -272,19 +264,10 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
var output *mlx.Array
if useCFG {
-// CFG Batching: single forward pass with batch=2
-// Note: layer caching with CFG is not supported yet (would need 2 caches)
-batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
-batchedTimestep := mlx.Tile(timestep, []int32{2})
-// Single batched forward pass
-batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
-L := batchedOutput.Shape()[1]
-D := batchedOutput.Shape()[2]
-posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
-negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
+// True CFG: run twice and combine with norm rescaling
+posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
+negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
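The combine step being built here is the standard guidance mix, out = neg + s·(pos − neg); the norm-rescaled variant used elsewhere (applyCFGWithNormRescale) additionally pulls the result's magnitude toward the positive branch. A plain, non-rescaled sketch, assuming an elementwise Add exists alongside the Sub/MulScalar shown above:

```go
package zimage

import "github.com/ollama/ollama/x/imagegen/mlx"

// applyCFG sketches classic classifier-free guidance without the norm
// rescale. mlx.Add is assumed to exist alongside Sub/MulScalar.
func applyCFG(pos, neg *mlx.Array, scale float32) *mlx.Array {
	diff := mlx.Sub(pos, neg)            // (pos - neg)
	scaled := mlx.MulScalar(diff, scale) // s * (pos - neg)
	return mlx.Add(neg, scaled)          // neg + s*(pos - neg)
}
```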
@@ -322,9 +305,6 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {


@@ -241,14 +241,6 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
@@ -299,18 +291,11 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
var output *mlx.Array
if useCFG {
-// CFG Batching: single forward pass with batch=2
-// Tile inputs: [1, L, D] -> [2, L, D]
-batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
-batchedTimestep := mlx.Tile(timestep, []int32{2})
-// Single batched forward pass
-batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
-D := batchedOutput.Shape()[2]
-posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
-negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
+posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
+negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
+posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
+negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
@@ -332,9 +317,6 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()


@@ -28,12 +28,12 @@ type Qwen3Config struct {
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
-QProj nn.LinearLayer `weight:"q_proj"`
-KProj nn.LinearLayer `weight:"k_proj"`
-VProj nn.LinearLayer `weight:"v_proj"`
-OProj nn.LinearLayer `weight:"o_proj"`
+QProj *nn.Linear `weight:"q_proj"`
+KProj *nn.Linear `weight:"k_proj"`
+VProj *nn.Linear `weight:"v_proj"`
+OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
-GateProj nn.LinearLayer `weight:"gate_proj"`
-UpProj nn.LinearLayer `weight:"up_proj"`
-DownProj nn.LinearLayer `weight:"down_proj"`
+GateProj *nn.Linear `weight:"gate_proj"`
+UpProj *nn.Linear `weight:"up_proj"`
+DownProj *nn.Linear `weight:"down_proj"`
}
// Forward applies the MLP


@@ -36,8 +36,8 @@ type TransformerConfig struct {
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
-Linear1 nn.LinearLayer `weight:"mlp.0"`
-Linear2 nn.LinearLayer `weight:"mlp.2"`
+Linear1 *nn.Linear `weight:"mlp.0"`
+Linear2 *nn.Linear `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
-Linear nn.LinearLayer `weight:"2-1"`
+Linear *nn.Linear `weight:"2-1"`
}
// Forward embeds patchified image latents
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
-Linear nn.LinearLayer `weight:"1"`
+Linear *nn.Linear `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
@@ -100,13 +100,12 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// FeedForward implements SwiGLU FFN
type FeedForward struct {
-W1 nn.LinearLayer `weight:"w1"` // gate projection
-W2 nn.LinearLayer `weight:"w2"` // down projection
-W3 nn.LinearLayer `weight:"w3"` // up projection
+W1 *nn.Linear `weight:"w1"` // gate projection
+W2 *nn.Linear `weight:"w2"` // down projection
+W3 *nn.Linear `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -116,7 +115,6 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
@@ -128,69 +126,17 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Attention implements multi-head attention with QK norm
type Attention struct {
-ToQ nn.LinearLayer `weight:"to_q"`
-ToK nn.LinearLayer `weight:"to_k"`
-ToV nn.LinearLayer `weight:"to_v"`
-ToOut nn.LinearLayer `weight:"to_out.0"`
+ToQ *nn.Linear `weight:"to_q"`
+ToK *nn.Linear `weight:"to_k"`
+ToV *nn.Linear `weight:"to_v"`
+ToOut *nn.Linear `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
-// Fused QKV (computed at init time for efficiency, not loaded from weights)
-ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
-Fused bool `weight:"-"` // Whether to use fused QKV path
-// Computed fields (not loaded from weights)
-NHeads int32 `weight:"-"`
-HeadDim int32 `weight:"-"`
-Dim int32 `weight:"-"`
-Scale float32 `weight:"-"`
-}
// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}
// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight
// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)
// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}
// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}
attn.ToQKV = fusedLinear
attn.Fused = true
+// Computed fields
+NHeads int32
+HeadDim int32
+Dim int32
+Scale float32
}
// Forward computes attention
@@ -200,24 +146,11 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
L := shape[1]
D := shape[2]
-// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
-var q, k, v *mlx.Array
-if attn.Fused && attn.ToQKV != nil {
-// Fused QKV path: single matmul then split
-qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
-// Split into Q, K, V along last dimension
-// Each has shape [B*L, dim]
-q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
-k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
-v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
-} else {
-// Separate Q, K, V projections
-q = attn.ToQ.Forward(xFlat)
-k = attn.ToK.Forward(xFlat)
-v = attn.ToV.Forward(xFlat)
-}
+q := attn.ToQ.Forward(xFlat)
+k := attn.ToK.Forward(xFlat)
+v := attn.ToV.Forward(xFlat)
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -294,7 +227,7 @@ type TransformerBlock struct {
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
-AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
+AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -348,8 +281,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]

AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]

OutDim int32 // computed from Output
}
@@ -417,11 +350,12 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)

// Load weights from tensor blobs with BF16 conversion
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
    return fmt.Errorf("weights: %w", err)
}

if err := weights.Load(0); err != nil {
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
    return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
@@ -443,7 +377,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256

m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]

m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
@@ -457,20 +391,6 @@ func (m *Transformer) initComputedFields() {
}
}
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
@@ -484,7 +404,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))

// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]

// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
@@ -503,8 +423,6 @@ type RoPECache struct {
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -538,8 +456,6 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}
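The body of PrepareRoPECache is largely elided in this hunk, but the general shape of such a precompute is a cos/sin table indexed by position and frequency, built once and reused across all denoising steps. A rough sketch of the textbook RoPE recipe (base 10000 and this exact frequency layout are assumptions, not necessarily what this model uses for its 2D image grid):

```go
package main

import (
	"fmt"
	"math"
)

// ropeTables precomputes cos/sin for seqLen positions and headDim dims.
// Standard RoPE: freq_i = base^(-2i/headDim), angle = pos * freq_i.
func ropeTables(seqLen, headDim int, base float64) (cos, sin [][]float64) {
	cos = make([][]float64, seqLen)
	sin = make([][]float64, seqLen)
	for pos := 0; pos < seqLen; pos++ {
		cos[pos] = make([]float64, headDim/2)
		sin[pos] = make([]float64, headDim/2)
		for i := 0; i < headDim/2; i++ {
			freq := math.Pow(base, -2*float64(i)/float64(headDim))
			angle := float64(pos) * freq
			cos[pos][i] = math.Cos(angle)
			sin[pos][i] = math.Sin(angle)
		}
	}
	return cos, sin
}

func main() {
	cos, sin := ropeTables(4, 8, 10000)
	fmt.Println(cos[1][:2], sin[1][:2])
}
```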

View File

@@ -104,8 +104,6 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
groupSize := C / gn.NumGroups

// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)

// Compute per-group mean and variance using flattened spatial dimensions
@@ -207,10 +205,6 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
}

// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
    x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
@@ -740,26 +734,18 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
h := vae.ConvIn.Forward(z)
mlx.Eval(h)

prev := h
h = vae.MidBlock.Forward(h)
prev.Free()

for _, upBlock := range vae.UpBlocks {
    prev = h
    h = upBlock.Forward(h)
    prev.Free()
}

prev = h
prev := h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()

// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
@@ -768,6 +754,7 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
prev.Free()
mlx.Eval(h)
return h
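The Decode chain follows one memory discipline throughout: hold a reference to the previous intermediate, run the next stage, then free the old array as soon as it has been consumed. A toy sketch of the pattern with a stand-in Free-able type (names here are hypothetical, not the mlx API):

```go
package main

import "fmt"

// buf stands in for an accelerator array with a manual lifetime.
type buf struct{ name string }

func (b *buf) Free() { fmt.Println("freed", b.name) }

// stage stands in for a layer that consumes one buffer and produces another.
func stage(in *buf, name string) *buf { return &buf{name: name} }

func main() {
	h := &buf{name: "input"}
	for _, s := range []string{"mid", "up0", "up1", "out"} {
		prev := h
		h = stage(h, s) // produce the next intermediate
		prev.Free()     // release the old one as soon as it is consumed
	}
	fmt.Println("final:", h.name)
}
```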

View File

@@ -26,12 +26,10 @@ type GenerateConfig struct {
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)

// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)

// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)

// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
}
// ProgressFunc is called during generation with step progress.
@@ -44,7 +42,6 @@ type Model struct {
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}
// Load loads the Z-Image model from ollama blob storage.
@@ -199,17 +196,13 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if cfg.CFGScale <= 0 {
    cfg.CFGScale = 4.0
}

// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
    cfg.TeaCacheThreshold = 0.15
}

// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
    m.Transformer.FuseAllQKV()
    m.qkvFused = true
    fmt.Println("  Fused QKV enabled")
}

if cfg.LayerCache {
    if cfg.CacheInterval <= 0 {
        cfg.CacheInterval = 3
    }
    if cfg.CacheLayers <= 0 {
        cfg.CacheLayers = 15 // Half of 30 layers
    }
}

useCFG := cfg.NegativePrompt != ""
@@ -267,54 +260,12 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(ropeCache.UnifiedCos)
}

// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
    // Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
    batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
    mlx.Keep(batchedEmb)
    mlx.Eval(batchedEmb)
}

// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
    skipEarly := 0
    if useCFG {
        skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
    }
    teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
        Threshold:      cfg.TeaCacheThreshold,
        RescaleFactor:  1.0,
        SkipEarlySteps: skipEarly,
    })
    if useCFG {
        fmt.Printf("  TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
    } else {
        fmt.Printf("  TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
    }
}

// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
    posEmb.Free()
    if negEmb != nil {
        negEmb.Free()
    }
    ropeCache.ImgCos.Free()
    ropeCache.ImgSin.Free()
    ropeCache.CapCos.Free()
    ropeCache.CapSin.Free()
    ropeCache.UnifiedCos.Free()
    ropeCache.UnifiedSin.Free()
    if batchedEmb != nil {
        batchedEmb.Free()
    }
    if teaCache != nil {
        teaCache.Free()
    }
    latents.Free()
}

// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
    stepCache = cache.NewStepCache(cfg.CacheLayers)
    fmt.Printf("  Layer caching enabled: %d layers, refresh every %d steps\n",
        cfg.CacheLayers, cfg.CacheInterval)
}
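The removed TeaCache setup is a threshold policy: recompute only when the conditioning signal has drifted enough since the last full forward pass, otherwise reuse the cached prediction. A self-contained sketch of that policy under assumed semantics (the real TeaCache keys on timestep-embedding change; these names and the relative-L1 test are illustrative, not this codebase's API):

```go
package main

import (
	"fmt"
	"math"
)

// stepCache reuses the previous output while the conditioning signal
// changes by less than threshold (relative change), TeaCache-style.
type stepCache struct {
	threshold    float64
	lastSig      float64
	cached       []float64
	hits, misses int
}

func (c *stepCache) shouldCompute(sig float64) bool {
	if c.cached == nil {
		return true // nothing cached yet
	}
	rel := math.Abs(sig-c.lastSig) / math.Max(math.Abs(c.lastSig), 1e-8)
	return rel > c.threshold
}

func (c *stepCache) update(sig float64, out []float64) { c.lastSig, c.cached = sig, out }

func main() {
	c := &stepCache{threshold: 0.15}
	for _, t := range []float64{1.0, 0.97, 0.95, 0.7, 0.68} {
		if c.shouldCompute(t) {
			c.misses++
			c.update(t, []float64{t}) // stand-in for a full forward pass
		} else {
			c.hits++
		}
	}
	fmt.Printf("%d hits, %d misses\n", c.hits, c.misses) // 3 hits, 2 misses
}
```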
// Denoising loop
@@ -326,7 +277,6 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if ctx != nil {
    select {
    case <-ctx.Done():
        cleanup()
        return nil, ctx.Err()
    default:
    }
}
@@ -339,77 +289,50 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
}

tCurr := scheduler.Timesteps[i]

var noisePred *mlx.Array

// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)

if shouldCompute {
    timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
    patches := PatchifyLatents(latents, tcfg.PatchSize)

    var output *mlx.Array
    if useCFG {
        // CFG Batching: single forward pass with batch=2
        // Tile patches: [1, L, D] -> [2, L, D]
        batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
        // Tile timestep: [1] -> [2]
        batchedTimestep := mlx.Tile(timestep, []int32{2})

        // Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
        batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)

        // Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
        outputShape := batchedOutput.Shape()
        L := outputShape[1]
        D := outputShape[2]
        posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
        negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})

        // Convert to noise predictions (unpatchify and negate)
        posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
        posPred = mlx.Neg(posPred)
        negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
        negPred = mlx.Neg(negPred)

        // Cache pos/neg separately for TeaCache
        if teaCache != nil {
            teaCache.UpdateCFGCache(posPred, negPred, tCurr)
            mlx.Keep(teaCache.Arrays()...)
        }

        // Apply CFG: noisePred = neg + scale * (pos - neg)
        diff := mlx.Sub(posPred, negPred)
        scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
        noisePred = mlx.Add(negPred, scaledDiff)
    } else {
        // Non-CFG forward pass
        output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
        noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
        noisePred = mlx.Neg(noisePred)
        // Update TeaCache
        if teaCache != nil {
            teaCache.UpdateCache(noisePred, tCurr)
            mlx.Keep(teaCache.Arrays()...)
        }
    }
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
    // CFG mode: get cached pos/neg and compute CFG fresh
    posPred, negPred := teaCache.GetCFGCached()
    diff := mlx.Sub(posPred, negPred)
    scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
    noisePred = mlx.Add(negPred, scaledDiff)
    fmt.Printf("  [TeaCache: reusing cached pos/neg outputs]\n")
} else {
    // Non-CFG mode: reuse cached noise prediction
    noisePred = teaCache.GetCached()
    fmt.Printf("  [TeaCache: reusing cached output]\n")
}

timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)

var output *mlx.Array
if stepCache != nil {
    // Use layer caching for faster inference
    if useCFG {
        posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
            stepCache, i, cfg.CacheInterval)
        // Note: CFG with layer cache shares the cache between pos/neg
        // This is approximate but fast - neg prompt uses same cached shallow layers
        negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
            stepCache, i, cfg.CacheInterval)
        diff := mlx.Sub(posOutput, negOutput)
        scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
        output = mlx.Add(negOutput, scaledDiff)
    } else {
        output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
            stepCache, i, cfg.CacheInterval)
    }
} else {
    // Standard forward without caching
    if useCFG {
        posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
        negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
        diff := mlx.Sub(posOutput, negOutput)
        scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
        output = mlx.Add(negOutput, scaledDiff)
    } else {
        output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
    }
}

noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)

oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)

// Keep latents and any cached arrays
if stepCache != nil {
    mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
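Both the old and the new branch end in the same classifier-free guidance combine: noisePred = neg + scale*(pos - neg), so scale=1 reproduces the conditional prediction and larger scales extrapolate past it. A tiny runnable illustration:

```go
package main

import "fmt"

// applyCFG combines conditional and unconditional predictions:
// out = neg + scale*(pos - neg).
func applyCFG(pos, neg []float64, scale float64) []float64 {
	out := make([]float64, len(pos))
	for i := range pos {
		out[i] = neg[i] + scale*(pos[i]-neg[i])
	}
	return out
}

func main() {
	pos := []float64{1.0, 2.0}
	neg := []float64{0.0, 1.0}
	fmt.Println(applyCFG(pos, neg, 4.0)) // [4 5]
}
```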
@@ -438,14 +361,8 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()

if batchedEmb != nil {
    batchedEmb.Free()
}
if teaCache != nil {
    hits, misses := teaCache.Stats()
    fmt.Printf("  TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
        hits, misses, float64(hits)/float64(hits+misses)*100)
    teaCache.Free()
}

if stepCache != nil {
    stepCache.Free()
}
// VAE decode

View File

@@ -10,13 +10,6 @@ type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32 // Returns the output dimension of the layer
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -56,11 +49,6 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
return mlx.Linear(x, w)
}
// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
return l.Weight.Shape()[0]
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -96,13 +84,6 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
return out
}
// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
return ql.Weight.Shape()[0]
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`
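The y = x @ W.T + b convention noted above, with the weight stored [out_features, in_features], is easy to check by hand. A self-contained illustration in plain Go (not the mlx kernels):

```go
package main

import "fmt"

// linear computes y = x @ W.T + b with W stored [out][in],
// matching the PyTorch/MLX convention.
func linear(x []float64, w [][]float64, b []float64) []float64 {
	y := make([]float64, len(w))
	for o, row := range w {
		for i, v := range row {
			y[o] += x[i] * v
		}
		if b != nil {
			y[o] += b[o]
		}
	}
	return y
}

func main() {
	w := [][]float64{{1, 2}, {3, 4}, {5, 6}} // 3 outputs x 2 inputs
	fmt.Println(linear([]float64{1, 1}, w, []float64{0.5, 0.5, 0.5}))
	// [3.5 7.5 11.5]
}
```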

View File

@@ -1,22 +0,0 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}
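The policy reads directly off the removed code: nothing in the VAE is quantized, embeddings and norms are skipped, and only tensors ending in .weight qualify. A runnable walkthrough (the tensor names below are invented for illustration):

```go
package main

import (
	"fmt"
	"strings"
)

// Mirror of the removed ShouldQuantize policy, for illustration:
// quantize linear weights only; skip VAE, embeddings, norms, and biases.
func shouldQuantize(name, component string) bool {
	if component == "vae" {
		return false
	}
	if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
		return false
	}
	return strings.HasSuffix(name, ".weight")
}

func main() {
	for _, tc := range []struct{ name, comp string }{
		{"layers.0.attention.to_q.weight", "transformer"}, // true
		{"layers.0.attention.to_q.bias", "transformer"},   // false: not .weight
		{"cap_embedder.norm.weight", "transformer"},       // false: norm
		{"decoder.conv_in.weight", "vae"},                 // false: vae
	} {
		fmt.Println(tc.name, shouldQuantize(tc.name, tc.comp))
	}
}
```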

View File

@@ -13,6 +13,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@@ -33,8 +34,7 @@ type Request struct {
// Response is streamed back for each progress update
type Response struct {
    Content string `json:"content,omitempty"`
    Image   string `json:"image,omitempty"` // Base64-encoded PNG

    Content string `json:"content"`

    Done bool `json:"done"`
}
@@ -191,10 +191,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
    resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}

// Save image
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
if err := imagegen.SaveImage(img, outPath); err != nil {
    resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}

    data, _ := json.Marshal(resp)
    w.Write(data)
    w.Write([]byte("\n"))
@@ -204,12 +204,11 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()

// Send final response with image data
resp := Response{
    Image: imageData,
    Done:  true,
}

// Send final response
resp := Response{
    Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
    Done:    true,
}

data, _ := json.Marshal(resp)
w.Write(data)

View File

@@ -8,7 +8,6 @@ import (
"strings" "strings"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
) )
// WeightSource is an interface for loading weights.
@@ -103,22 +102,6 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
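The removed branch uses the canonical Go trick for matching an interface-typed struct field: reflect.TypeOf((*Iface)(nil)).Elem() yields the interface type itself, and any implementation can then be assigned with Set. A minimal standalone demonstration (toy types, not this loader):

```go
package main

import (
	"fmt"
	"reflect"
)

type Greeter interface{ Greet() string }

type English struct{}

func (English) Greet() string { return "hello" }

type Model struct {
	G Greeter
}

func main() {
	m := &Model{}
	v := reflect.ValueOf(m).Elem()
	f := v.Type().Field(0)

	// Detect an interface-typed field the same way the loader did:
	// compare against reflect.TypeOf((*Iface)(nil)).Elem().
	if f.Type == reflect.TypeOf((*Greeter)(nil)).Elem() {
		v.Field(0).Set(reflect.ValueOf(English{})) // assign any implementation
	}
	fmt.Println(m.G.Greet()) // hello
}
```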
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -193,64 +176,3 @@ func joinPath(prefix, suffix string) string {
}
return prefix + "." + suffix
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
}
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
}, nil
}
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
return nn.NewLinear(dequantized, bias), nil
}
// Load as regular Linear
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
return nn.NewLinear(weight, bias), nil
}

View File

@@ -14,9 +14,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -48,8 +46,7 @@ type completionRequest struct {
// completionResponse is received from the subprocess
type completionResponse struct {
    Content string `json:"content,omitempty"`
    Image   string `json:"image,omitempty"`

    Content string `json:"content"`

    Done bool `json:"done"`
}
@@ -72,7 +69,7 @@ func NewServer(modelName string) (*Server, error) {
port = rand.Intn(65535-49152) + 49152
}

// Get the ollama-mlx executable path (in same directory as current executable)
// Get the ollama executable path
exe, err := os.Executable()
if err != nil {
    return nil, fmt.Errorf("unable to lookup executable path: %w", err)
@@ -80,42 +77,11 @@ func NewServer(modelName string) (*Server, error) {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
    exe = eval
}
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")

// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))

// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))

cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
s := &Server{
    cmd:  cmd,
    port: port,
@@ -146,7 +112,7 @@ func NewServer(modelName string) (*Server, error) {
}
}()

slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
slog.Info("starting image runner subprocess", "model", modelName, "port", port)

if err := cmd.Start(); err != nil {
    return nil, fmt.Errorf("failed to start image runner: %w", err)
}
@@ -284,23 +250,15 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return fmt.Errorf("completion request failed: %d", resp.StatusCode) return fmt.Errorf("completion request failed: %d", resp.StatusCode)
} }
// Stream responses - use large buffer for base64 image data // Stream responses
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() { for scanner.Scan() {
var cresp completionResponse var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil { if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
continue continue
} }
content := cresp.Content
// If this is the final response with an image, encode it in the content
if cresp.Done && cresp.Image != "" {
content = "IMAGE_BASE64:" + cresp.Image
}
fn(llm.CompletionResponse{ fn(llm.CompletionResponse{
Content: content, Content: cresp.Content,
Done: cresp.Done, Done: cresp.Done,
}) })
if cresp.Done { if cresp.Done {

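The removed buffer sizing matters because bufio.Scanner rejects lines longer than its buffer (64KB by default), and a base64-encoded PNG easily exceeds that on a single NDJSON line. A self-contained illustration of the same setup against an in-memory stream:

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// One JSON object per line; a base64 PNG can make a single line
	// several megabytes, so the default 64KB scanner limit must be raised.
	body := strings.NewReader(`{"content":"step 1"}` + "\n" +
		`{"image":"aGVsbG8=","done":true}` + "\n")

	scanner := bufio.NewScanner(body)
	scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max line

	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
}
```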
View File

@@ -45,33 +45,24 @@ func download(ctx context.Context, opts DownloadOptions) error {
return nil
}

// Calculate total from all blobs (for accurate progress reporting on resume)
var total int64
for _, b := range opts.Blobs {
    total += b.Size
}

// Filter out already-downloaded blobs and track completed bytes
var blobs []Blob
var alreadyCompleted int64
for _, b := range opts.Blobs {
    if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
        if opts.Logger != nil {
            opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
        }
        alreadyCompleted += b.Size
        continue
    }
    blobs = append(blobs, b)
}

// Filter existing
var blobs []Blob
var total int64
for _, b := range opts.Blobs {
    if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
        if opts.Logger != nil {
            opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
        }
        continue
    }
    blobs = append(blobs, b)
    total += b.Size
}

if len(blobs) == 0 {
    return nil
}

token := opts.Token

progress := newProgressTracker(total, opts.Progress)
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront

d := &downloader{
    client:  cmp.Or(opts.Client, defaultClient),
    baseURL: opts.BaseURL,
@@ -81,7 +72,7 @@ func download(ctx context.Context, opts DownloadOptions) error {
getToken:     opts.GetToken,
userAgent:    cmp.Or(opts.UserAgent, defaultUserAgent),
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),

progress: progress,
progress: newProgressTracker(total, opts.Progress),

speeds: &speedTracker{},
logger: opts.Logger,
}

View File

@@ -110,6 +110,8 @@ var defaultClient = &http.Client{
MaxIdleConnsPerHost: 100,
IdleConnTimeout:     90 * time.Second,
},
Timeout: 5 * time.Minute,

// Don't follow redirects automatically - we handle them manually
CheckRedirect: func(req *http.Request, via []*http.Request) error {
    return http.ErrUseLastResponse
},
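Returning http.ErrUseLastResponse from CheckRedirect makes the client hand back the 3xx response itself rather than chasing the Location header, which lets the downloader handle redirects on its own terms. A minimal sketch (the endpoint URL is just a placeholder for any redirecting server):

```go
package main

import (
	"fmt"
	"net/http"
	"time"
)

func main() {
	client := &http.Client{
		Timeout: 5 * time.Minute,
		// Surface the redirect response instead of following it.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	resp, err := client.Get("https://httpbin.org/redirect/1")
	if err != nil {
		fmt.Println(err)
		return
	}
	defer resp.Body.Close()
	// Prints the 3xx status and the target the client did NOT follow.
	fmt.Println(resp.StatusCode, resp.Header.Get("Location"))
}
```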

View File

@@ -284,83 +284,6 @@ func TestDownloadSkipsExisting(t *testing.T) {
}
}
func TestDownloadResumeProgressTotal(t *testing.T) {
// Test that when resuming a download with some blobs already present:
// 1. Total reflects ALL blob sizes (not just remaining)
// 2. Completed starts at the size of already-downloaded blobs
serverDir := t.TempDir()
blob1, data1 := createTestBlob(t, serverDir, 1000)
blob2, data2 := createTestBlob(t, serverDir, 2000)
blob3, data3 := createTestBlob(t, serverDir, 3000)
// Pre-populate client with blob1 and blob2 (simulating partial download)
clientDir := t.TempDir()
for _, b := range []struct {
blob Blob
data []byte
}{{blob1, data1}, {blob2, data2}} {
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(path, b.data, 0o644); err != nil {
t.Fatal(err)
}
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
digest := filepath.Base(r.URL.Path)
path := filepath.Join(serverDir, digestToPath(digest))
data, err := os.ReadFile(path)
if err != nil {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(http.StatusOK)
w.Write(data)
}))
defer server.Close()
var firstCompleted, firstTotal int64
var gotFirstProgress bool
var mu sync.Mutex
err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{blob1, blob2, blob3},
BaseURL: server.URL,
DestDir: clientDir,
Concurrency: 1,
Progress: func(completed, total int64) {
mu.Lock()
defer mu.Unlock()
if !gotFirstProgress {
firstCompleted = completed
firstTotal = total
gotFirstProgress = true
}
},
})
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Total should be sum of ALL blobs, not just blob3
expectedTotal := blob1.Size + blob2.Size + blob3.Size
if firstTotal != expectedTotal {
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
}
// First progress call should show already-completed bytes from blob1+blob2
expectedCompleted := blob1.Size + blob2.Size
if firstCompleted < expectedCompleted {
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
}
// Verify blob3 was downloaded
verifyBlob(t, clientDir, blob3, data3)
}
func TestDownloadDigestMismatch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return wrong data

View File

@@ -54,16 +54,6 @@ func (r *Registry) RegisterBash() {
r.Register(&BashTool{})
}
// RegisterWebSearch adds the web search tool to the registry.
func (r *Registry) RegisterWebSearch() {
r.Register(&WebSearchTool{})
}
// RegisterWebFetch adds the web fetch tool to the registry.
func (r *Registry) RegisterWebFetch() {
r.Register(&WebFetchTool{})
}
// Get retrieves a tool by name.
func (r *Registry) Get(name string) (Tool, bool) {
tool, ok := r.tools[name]

View File

@@ -1,162 +0,0 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
)
const (
webFetchAPI = "https://ollama.com/api/web_fetch"
webFetchTimeout = 30 * time.Second
)
// ErrWebFetchAuthRequired is returned when web fetch requires authentication
var ErrWebFetchAuthRequired = errors.New("web fetch requires authentication")
// WebFetchTool implements web page fetching using Ollama's hosted API.
type WebFetchTool struct{}
// Name returns the tool name.
func (w *WebFetchTool) Name() string {
return "web_fetch"
}
// Description returns a description of the tool.
func (w *WebFetchTool) Description() string {
return "Fetch and extract text content from a web page. Use this to read the full content of a URL found in search results or provided by the user."
}
// Schema returns the tool's parameter schema.
func (w *WebFetchTool) Schema() api.ToolFunction {
props := api.NewToolPropertiesMap()
props.Set("url", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The URL to fetch and extract content from",
})
return api.ToolFunction{
Name: w.Name(),
Description: w.Description(),
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"url"},
},
}
}
// webFetchRequest is the request body for the web fetch API.
type webFetchRequest struct {
URL string `json:"url"`
}
// webFetchResponse is the response from the web fetch API.
type webFetchResponse struct {
Title string `json:"title"`
Content string `json:"content"`
Links []string `json:"links,omitempty"`
}
// Execute fetches content from a web page.
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
func (w *WebFetchTool) Execute(args map[string]any) (string, error) {
urlStr, ok := args["url"].(string)
if !ok || urlStr == "" {
return "", fmt.Errorf("url parameter is required")
}
// Validate URL
if _, err := url.Parse(urlStr); err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
}
// Prepare request
reqBody := webFetchRequest{
URL: urlStr,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshaling request: %w", err)
}
// Parse URL and add timestamp for signing
fetchURL, err := url.Parse(webFetchAPI)
if err != nil {
return "", fmt.Errorf("parsing fetch URL: %w", err)
}
q := fetchURL.Query()
q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
fetchURL.RawQuery = q.Encode()
// Sign the request using Ollama key (~/.ollama/id_ed25519)
ctx := context.Background()
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, fetchURL.RequestURI())
signature, err := auth.Sign(ctx, data)
if err != nil {
return "", fmt.Errorf("signing request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fetchURL.String(), bytes.NewBuffer(jsonBody))
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
}
// Send request
client := &http.Client{Timeout: webFetchTimeout}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("sending request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized {
return "", ErrWebFetchAuthRequired
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("web fetch API returned status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var fetchResp webFetchResponse
if err := json.Unmarshal(body, &fetchResp); err != nil {
return "", fmt.Errorf("parsing response: %w", err)
}
// Format result
var sb strings.Builder
if fetchResp.Title != "" {
sb.WriteString(fmt.Sprintf("Title: %s\n\n", fetchResp.Title))
}
if fetchResp.Content != "" {
sb.WriteString("Content:\n")
sb.WriteString(fetchResp.Content)
} else {
sb.WriteString("No content could be extracted from the page.")
}
return sb.String(), nil
}
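For reference, calling the removed tool only required the url argument. A hypothetical same-package caller, reusing the file's existing errors/fmt imports (everything beyond WebFetchTool, Execute, and ErrWebFetchAuthRequired is invented for illustration):

```go
// Hypothetical usage, same package:
func fetchExample() (string, error) {
	tool := &WebFetchTool{}
	out, err := tool.Execute(map[string]any{"url": "https://example.com"})
	if errors.Is(err, ErrWebFetchAuthRequired) {
		return "", fmt.Errorf("sign in to ollama.com first: %w", err)
	}
	return out, err
}
```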