mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 04:51:17 -05:00
Compare commits
19 Commits
v0.14.0-rc
...
parth/decr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b2abfb433 | ||
|
|
805ed4644c | ||
|
|
e4b488a7b5 | ||
|
|
98079ddd79 | ||
|
|
d70942f47b | ||
|
|
58e4701557 | ||
|
|
dbf47ee55a | ||
|
|
af7ea6e96e | ||
|
|
8f1e0140e7 | ||
|
|
35c3c9e3c2 | ||
|
|
d06acbcb19 | ||
|
|
9667c2282f | ||
|
|
a937a68317 | ||
|
|
2185112d84 | ||
|
|
91926601dc | ||
|
|
361d6c16c2 | ||
|
|
7e2496e88e | ||
|
|
5b84e29882 | ||
|
|
7cc2a653f2 |
2
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
2
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
@@ -13,7 +13,7 @@ body:
|
|||||||
id: logs
|
id: logs
|
||||||
attributes:
|
attributes:
|
||||||
label: Relevant log output
|
label: Relevant log output
|
||||||
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
|
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
|
||||||
render: shell
|
render: shell
|
||||||
validations:
|
validations:
|
||||||
required: false
|
required: false
|
||||||
|
|||||||
6
.github/workflows/release.yaml
vendored
6
.github/workflows/release.yaml
vendored
@@ -372,13 +372,17 @@ jobs:
|
|||||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||||
cache-to: type=inline
|
cache-to: type=inline
|
||||||
|
- name: Deduplicate CUDA libraries
|
||||||
|
run: |
|
||||||
|
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||||
- run: |
|
- run: |
|
||||||
for COMPONENT in bin/* lib/ollama/*; do
|
for COMPONENT in bin/* lib/ollama/*; do
|
||||||
case "$COMPONENT" in
|
case "$COMPONENT" in
|
||||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
|
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||||
|
|||||||
@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
|||||||
set(GGML_CPU_ALL_VARIANTS ON)
|
set(GGML_CPU_ALL_VARIANTS ON)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
if(APPLE)
|
||||||
set(CMAKE_BUILD_RPATH "@loader_path")
|
set(CMAKE_BUILD_RPATH "@loader_path")
|
||||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||||
|
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||||
@@ -196,6 +197,14 @@ if(MLX_ENGINE)
|
|||||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Install the Metal library for macOS arm64 (must be colocated with the binary)
|
||||||
|
# Metal backend is only built for arm64, not x86_64
|
||||||
|
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||||
|
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
|
||||||
|
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||||
|
COMPONENT MLX)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||||
if(CUDAToolkit_FOUND)
|
if(CUDAToolkit_FOUND)
|
||||||
file(GLOB CUDART_LIBS
|
file(GLOB CUDART_LIBS
|
||||||
|
|||||||
@@ -161,6 +161,9 @@ ARG GOFLAGS="'-ldflags=-w -s'"
|
|||||||
ENV CGO_ENABLED=1
|
ENV CGO_ENABLED=1
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
ARG CGO_CXXFLAGS
|
ARG CGO_CXXFLAGS
|
||||||
|
RUN mkdir -p dist/bin
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||||
|
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxBufferSize = 512 * format.KiloByte
|
const maxBufferSize = 8 * format.MegaByte
|
||||||
|
|
||||||
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
||||||
var buf io.Reader
|
var buf io.Reader
|
||||||
|
|||||||
34
cmd/cmd.go
34
cmd/cmd.go
@@ -100,7 +100,8 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if filename == "" {
|
if filename == "" {
|
||||||
// No Modelfile found - check if current directory is an image gen model
|
// No Modelfile found - check if current directory is an image gen model
|
||||||
if imagegen.IsTensorModelDir(".") {
|
if imagegen.IsTensorModelDir(".") {
|
||||||
return imagegenclient.CreateModel(args[0], ".", p)
|
quantize, _ := cmd.Flags().GetString("quantize")
|
||||||
|
return imagegenclient.CreateModel(args[0], ".", quantize, p)
|
||||||
}
|
}
|
||||||
reader = strings.NewReader("FROM .\n")
|
reader = strings.NewReader("FROM .\n")
|
||||||
} else {
|
} else {
|
||||||
@@ -464,14 +465,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
name := args[0]
|
name := args[0]
|
||||||
|
|
||||||
// Check if this is a known image generation model (skip Show/Pull)
|
|
||||||
if imagegen.HasTensorLayers(name) {
|
|
||||||
if opts.Prompt == "" && !interactive {
|
|
||||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
|
||||||
}
|
|
||||||
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := func() (*api.ShowResponse, error) {
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
showReq := &api.ShowRequest{Name: name}
|
showReq := &api.ShowRequest{Name: name}
|
||||||
info, err := client.Show(cmd.Context(), showReq)
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
@@ -533,9 +526,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if this is an image generation model
|
||||||
|
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
|
||||||
|
if opts.Prompt == "" && !interactive {
|
||||||
|
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||||
|
}
|
||||||
|
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
|
||||||
|
}
|
||||||
|
|
||||||
// Check for experimental flag
|
// Check for experimental flag
|
||||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
|
||||||
|
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
|
||||||
|
|
||||||
if interactive {
|
if interactive {
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
@@ -565,7 +567,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
// Use experimental agent loop with tools
|
// Use experimental agent loop with tools
|
||||||
if isExperimental {
|
if isExperimental {
|
||||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
|
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
|
||||||
}
|
}
|
||||||
|
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
@@ -671,7 +673,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
bar, ok := bars[resp.Digest]
|
bar, ok := bars[resp.Digest]
|
||||||
if !ok {
|
if !ok {
|
||||||
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
msg := resp.Status
|
||||||
|
if msg == "" {
|
||||||
|
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
|
||||||
|
}
|
||||||
|
bar = progress.NewBar(msg, resp.Total, resp.Completed)
|
||||||
bars[resp.Digest] = bar
|
bars[resp.Digest] = bar
|
||||||
p.Add(resp.Digest, bar)
|
p.Add(resp.Digest, bar)
|
||||||
}
|
}
|
||||||
@@ -837,11 +843,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ShowHandler(cmd *cobra.Command, args []string) error {
|
func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||||
// Check if this is an image generation model
|
|
||||||
if imagegen.HasTensorLayers(args[0]) {
|
|
||||||
return imagegen.Show(args[0], os.Stdout)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1786,6 +1787,7 @@ func NewCLI() *cobra.Command {
|
|||||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||||
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
|
||||||
|
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
|
||||||
|
|
||||||
// Image generation flags (width, height, steps, seed, etc.)
|
// Image generation flags (width, height, steps, seed, etc.)
|
||||||
imagegen.RegisterFlags(runCmd)
|
imagegen.RegisterFlags(runCmd)
|
||||||
|
|||||||
@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShowInfoImageGen(t *testing.T) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := showInfo(&api.ShowResponse{
|
||||||
|
Details: api.ModelDetails{
|
||||||
|
Family: "ZImagePipeline",
|
||||||
|
ParameterSize: "10.3B",
|
||||||
|
QuantizationLevel: "FP8",
|
||||||
|
},
|
||||||
|
Capabilities: []model.Capability{model.CapabilityImageGeneration},
|
||||||
|
Requires: "0.14.0",
|
||||||
|
}, false, &b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expect := " Model\n" +
|
||||||
|
" architecture ZImagePipeline \n" +
|
||||||
|
" parameters 10.3B \n" +
|
||||||
|
" quantization FP8 \n" +
|
||||||
|
" requires 0.14.0 \n" +
|
||||||
|
"\n" +
|
||||||
|
" Capabilities\n" +
|
||||||
|
" image \n" +
|
||||||
|
"\n"
|
||||||
|
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||||
|
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPushProgressMessage(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status string
|
||||||
|
digest string
|
||||||
|
wantMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "uses status when provided",
|
||||||
|
status: "uploading model",
|
||||||
|
digest: "sha256:abc123456789def",
|
||||||
|
wantMsg: "uploading model",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "falls back to digest when status empty",
|
||||||
|
status: "",
|
||||||
|
digest: "sha256:abc123456789def",
|
||||||
|
wantMsg: "pushing abc123456789...",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles short digest gracefully",
|
||||||
|
status: "",
|
||||||
|
digest: "sha256:abc",
|
||||||
|
wantMsg: "pushing sha256:abc...",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
msg := tt.status
|
||||||
|
if msg == "" {
|
||||||
|
if len(tt.digest) >= 19 {
|
||||||
|
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
|
||||||
|
} else {
|
||||||
|
msg = fmt.Sprintf("pushing %s...", tt.digest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if msg != tt.wantMsg {
|
||||||
|
t.Errorf("got %q, want %q", msg, tt.wantMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRunOptions_Copy_Independence(t *testing.T) {
|
func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||||
// Test that modifications to original don't affect copy
|
// Test that modifications to original don't affect copy
|
||||||
originalThink := &api.ThinkValue{Value: "original"}
|
originalThink := &api.ThinkValue{Value: "original"}
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
# Troubleshooting
|
|
||||||
|
|
||||||
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)
|
|
||||||
@@ -118,6 +118,9 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||||
|
c.Set("relax_thinking", true)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
|||||||
@@ -582,3 +582,26 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var flagSet bool
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(AnthropicMessagesMiddleware())
|
||||||
|
router.POST("/v1/messages", func(c *gin.Context) {
|
||||||
|
_, flagSet = c.Get("relax_thinking")
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
if !flagSet {
|
||||||
|
t.Error("expected relax_thinking flag to be set in context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ _build_darwin() {
|
|||||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||||
fi
|
fi
|
||||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
|
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
|
||||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||||
done
|
done
|
||||||
}
|
}
|
||||||
@@ -82,19 +82,19 @@ _sign_darwin() {
|
|||||||
status "Creating universal binary..."
|
status "Creating universal binary..."
|
||||||
mkdir -p dist/darwin
|
mkdir -p dist/darwin
|
||||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||||
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
|
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
|
||||||
chmod +x dist/darwin/ollama
|
chmod +x dist/darwin/ollama
|
||||||
chmod +x dist/darwin/imagegen
|
chmod +x dist/darwin/ollama-mlx
|
||||||
|
|
||||||
if [ -n "$APPLE_IDENTITY" ]; then
|
if [ -n "$APPLE_IDENTITY" ]; then
|
||||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
|
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
|
||||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||||
done
|
done
|
||||||
|
|
||||||
# create a temporary zip for notarization
|
# create a temporary zip for notarization
|
||||||
TEMP=$(mktemp -u).zip
|
TEMP=$(mktemp -u).zip
|
||||||
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
|
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
|
||||||
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
|
||||||
rm -f "$TEMP"
|
rm -f "$TEMP"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@@ -154,23 +154,25 @@ _build_macapp() {
|
|||||||
mkdir -p dist/Ollama.app/Contents/Resources
|
mkdir -p dist/Ollama.app/Contents/Resources
|
||||||
if [ -d dist/darwin-amd64 ]; then
|
if [ -d dist/darwin-amd64 ]; then
|
||||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||||
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
|
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
|
||||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||||
done
|
done
|
||||||
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||||
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||||
|
# Copy MLX metallib (architecture-independent, just use arm64 version)
|
||||||
|
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
|
||||||
else
|
else
|
||||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||||
fi
|
fi
|
||||||
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
|
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
|
||||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||||
|
|
||||||
# Sign
|
# Sign
|
||||||
if [ -n "$APPLE_IDENTITY" ]; then
|
if [ -n "$APPLE_IDENTITY" ]; then
|
||||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
|
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
|
||||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||||
done
|
done
|
||||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||||
@@ -178,11 +180,11 @@ _build_macapp() {
|
|||||||
|
|
||||||
rm -f dist/Ollama-darwin.zip
|
rm -f dist/Ollama-darwin.zip
|
||||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||||
|
|
||||||
# Notarize and Staple
|
# Notarize and Staple
|
||||||
if [ -n "$APPLE_IDENTITY" ]; then
|
if [ -n "$APPLE_IDENTITY" ]; then
|
||||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||||
rm -f dist/Ollama-darwin.zip
|
rm -f dist/Ollama-darwin.zip
|
||||||
$(xcrun -f stapler) staple dist/Ollama.app
|
$(xcrun -f stapler) staple dist/Ollama.app
|
||||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||||
@@ -206,7 +208,7 @@ _build_macapp() {
|
|||||||
rm -f dist/rw*.dmg
|
rm -f dist/rw*.dmg
|
||||||
|
|
||||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
|
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
|
||||||
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||||
$(xcrun -f stapler) staple dist/Ollama.dmg
|
$(xcrun -f stapler) staple dist/Ollama.dmg
|
||||||
else
|
else
|
||||||
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"
|
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"
|
||||||
|
|||||||
@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
|||||||
.
|
.
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
|
||||||
deduplicate_cuda_libs() {
|
|
||||||
local base_dir="$1"
|
|
||||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
|
||||||
|
|
||||||
# Find all mlx_cuda_* directories
|
|
||||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
|
||||||
[ -d "${mlx_dir}" ] || continue
|
|
||||||
|
|
||||||
# Extract CUDA version (e.g., v12, v13)
|
|
||||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
|
||||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
|
||||||
|
|
||||||
# Skip if corresponding cuda_* directory doesn't exist
|
|
||||||
[ -d "${cuda_dir}" ] || continue
|
|
||||||
|
|
||||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
|
||||||
|
|
||||||
# Find all .so* files in mlx directory
|
|
||||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
|
||||||
filename=$(basename "${mlx_file}")
|
|
||||||
cuda_file="${cuda_dir}/${filename}"
|
|
||||||
|
|
||||||
# Skip if file doesn't exist in cuda directory
|
|
||||||
[ -f "${cuda_file}" ] || continue
|
|
||||||
|
|
||||||
# Compare checksums
|
|
||||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
|
||||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
|
||||||
|
|
||||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
|
||||||
echo " Deduplicating ${filename}"
|
|
||||||
# Calculate relative path from mlx_dir to cuda_dir
|
|
||||||
rel_path="../cuda_${cuda_version}/${filename}"
|
|
||||||
rm -f "${mlx_file}"
|
|
||||||
ln -s "${rel_path}" "${mlx_file}"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
done
|
|
||||||
}
|
|
||||||
|
|
||||||
# Run deduplication for each platform output directory
|
# Run deduplication for each platform output directory
|
||||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||||
deduplicate_cuda_libs "./dist/linux_amd64"
|
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
|
||||||
deduplicate_cuda_libs "./dist/linux_arm64"
|
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
|
||||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||||
deduplicate_cuda_libs "./dist"
|
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# buildx behavior changes for single vs. multiplatform
|
# buildx behavior changes for single vs. multiplatform
|
||||||
|
|||||||
60
scripts/deduplicate_cuda_libs.sh
Executable file
60
scripts/deduplicate_cuda_libs.sh
Executable file
@@ -0,0 +1,60 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
#
|
||||||
|
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||||
|
# This script finds identical .so* files in mlx_cuda_* directories that exist
|
||||||
|
# in corresponding cuda_* directories and replaces them with symlinks.
|
||||||
|
#
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
if [ $# -eq 0 ]; then
|
||||||
|
echo "ERROR: No directory specified" >&2
|
||||||
|
echo "Usage: $0 <base_directory>" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
base_dir="$1"
|
||||||
|
|
||||||
|
if [ ! -d "${base_dir}" ]; then
|
||||||
|
echo "ERROR: Directory ${base_dir} does not exist" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||||
|
|
||||||
|
# Find all mlx_cuda_* directories
|
||||||
|
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||||
|
[ -d "${mlx_dir}" ] || continue
|
||||||
|
|
||||||
|
# Extract CUDA version (e.g., v12, v13)
|
||||||
|
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||||
|
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||||
|
|
||||||
|
# Skip if corresponding cuda_* directory doesn't exist
|
||||||
|
[ -d "${cuda_dir}" ] || continue
|
||||||
|
|
||||||
|
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||||
|
|
||||||
|
# Find all .so* files in mlx directory
|
||||||
|
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||||
|
filename=$(basename "${mlx_file}")
|
||||||
|
cuda_file="${cuda_dir}/${filename}"
|
||||||
|
|
||||||
|
# Skip if file doesn't exist in cuda directory
|
||||||
|
[ -f "${cuda_file}" ] || continue
|
||||||
|
|
||||||
|
# Compare checksums
|
||||||
|
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||||
|
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||||
|
|
||||||
|
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||||
|
echo " Deduplicating ${filename}"
|
||||||
|
# Calculate relative path from mlx_dir to cuda_dir
|
||||||
|
rel_path="../cuda_${cuda_version}/${filename}"
|
||||||
|
rm -f "${mlx_file}"
|
||||||
|
ln -s "${rel_path}" "${mlx_file}"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Deduplication complete"
|
||||||
@@ -95,11 +95,48 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
numDownloadParts = 16
|
// numDownloadParts is the default number of concurrent download parts for standard downloads
|
||||||
|
numDownloadParts = 16
|
||||||
|
// numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
|
||||||
|
// downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
|
||||||
|
numHFDownloadParts = 4
|
||||||
minDownloadPartSize int64 = 100 * format.MegaByte
|
minDownloadPartSize int64 = 100 * format.MegaByte
|
||||||
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
|
||||||
|
// This includes:
|
||||||
|
// - huggingface.co (main domain)
|
||||||
|
// - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
|
||||||
|
// - hf.co (shortlink domain)
|
||||||
|
// - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
|
||||||
|
func isHuggingFaceURL(u *url.URL) bool {
|
||||||
|
if u == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
host := strings.ToLower(u.Hostname())
|
||||||
|
return host == "huggingface.co" ||
|
||||||
|
strings.HasSuffix(host, ".huggingface.co") ||
|
||||||
|
host == "hf.co" ||
|
||||||
|
strings.HasSuffix(host, ".hf.co")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNumDownloadParts returns the number of concurrent download parts to use
|
||||||
|
// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
|
||||||
|
// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
|
||||||
|
// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
|
||||||
|
func getNumDownloadParts(u *url.URL) int {
|
||||||
|
if isHuggingFaceURL(u) {
|
||||||
|
if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return numHFDownloadParts
|
||||||
|
}
|
||||||
|
return numDownloadParts
|
||||||
|
}
|
||||||
|
|
||||||
func (p *blobDownloadPart) Name() string {
|
func (p *blobDownloadPart) Name() string {
|
||||||
return strings.Join([]string{
|
return strings.Join([]string{
|
||||||
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
||||||
@@ -271,7 +308,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
|||||||
}
|
}
|
||||||
|
|
||||||
g, inner := errgroup.WithContext(ctx)
|
g, inner := errgroup.WithContext(ctx)
|
||||||
g.SetLimit(numDownloadParts)
|
concurrency := getNumDownloadParts(directURL)
|
||||||
|
if concurrency != numDownloadParts {
|
||||||
|
slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
|
||||||
|
}
|
||||||
|
g.SetLimit(concurrency)
|
||||||
for i := range b.Parts {
|
for i := range b.Parts {
|
||||||
part := b.Parts[i]
|
part := b.Parts[i]
|
||||||
if part.Completed.Load() == part.Size {
|
if part.Completed.Load() == part.Size {
|
||||||
|
|||||||
194
server/download_test.go
Normal file
194
server/download_test.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsHuggingFaceURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil url",
|
||||||
|
url: "",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface.co main domain",
|
||||||
|
url: "https://huggingface.co/some/model",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cdn-lfs.huggingface.co subdomain",
|
||||||
|
url: "https://cdn-lfs.huggingface.co/repos/abc/123",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cdn-lfs3.hf.co CDN domain",
|
||||||
|
url: "https://cdn-lfs3.hf.co/repos/abc/123",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hf.co shortlink domain",
|
||||||
|
url: "https://hf.co/model",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase HuggingFace domain",
|
||||||
|
url: "https://HUGGINGFACE.CO/model",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case HF domain",
|
||||||
|
url: "https://Cdn-Lfs.HF.Co/repos",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ollama registry",
|
||||||
|
url: "https://registry.ollama.ai/v2/library/llama3",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "github.com",
|
||||||
|
url: "https://github.com/ollama/ollama",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fake huggingface domain",
|
||||||
|
url: "https://nothuggingface.co/model",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fake hf domain",
|
||||||
|
url: "https://nothf.co/model",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface in path not host",
|
||||||
|
url: "https://example.com/huggingface.co/model",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var u *url.URL
|
||||||
|
if tc.url != "" {
|
||||||
|
var err error
|
||||||
|
u, err = url.Parse(tc.url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse URL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
got := isHuggingFaceURL(u)
|
||||||
|
assert.Equal(t, tc.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNumDownloadParts(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
envValue string
|
||||||
|
expected int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil url returns default",
|
||||||
|
url: "",
|
||||||
|
envValue: "",
|
||||||
|
expected: numDownloadParts,
|
||||||
|
description: "nil URL should return standard concurrency",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ollama registry returns default",
|
||||||
|
url: "https://registry.ollama.ai/v2/library/llama3",
|
||||||
|
envValue: "",
|
||||||
|
expected: numDownloadParts,
|
||||||
|
description: "Ollama registry should use standard concurrency",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface returns reduced default",
|
||||||
|
url: "https://huggingface.co/model/repo",
|
||||||
|
envValue: "",
|
||||||
|
expected: numHFDownloadParts,
|
||||||
|
description: "HuggingFace should use reduced concurrency",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hf.co CDN returns reduced default",
|
||||||
|
url: "https://cdn-lfs3.hf.co/repos/abc/123",
|
||||||
|
envValue: "",
|
||||||
|
expected: numHFDownloadParts,
|
||||||
|
description: "HuggingFace CDN should use reduced concurrency",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface with env override",
|
||||||
|
url: "https://huggingface.co/model/repo",
|
||||||
|
envValue: "2",
|
||||||
|
expected: 2,
|
||||||
|
description: "OLLAMA_HF_CONCURRENCY should override default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface with higher env override",
|
||||||
|
url: "https://huggingface.co/model/repo",
|
||||||
|
envValue: "8",
|
||||||
|
expected: 8,
|
||||||
|
description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface with invalid env (non-numeric)",
|
||||||
|
url: "https://huggingface.co/model/repo",
|
||||||
|
envValue: "invalid",
|
||||||
|
expected: numHFDownloadParts,
|
||||||
|
description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface with invalid env (zero)",
|
||||||
|
url: "https://huggingface.co/model/repo",
|
||||||
|
envValue: "0",
|
||||||
|
expected: numHFDownloadParts,
|
||||||
|
description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "huggingface with invalid env (negative)",
|
||||||
|
url: "https://huggingface.co/model/repo",
|
||||||
|
envValue: "-1",
|
||||||
|
expected: numHFDownloadParts,
|
||||||
|
description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-huggingface ignores env",
|
||||||
|
url: "https://registry.ollama.ai/v2/library/llama3",
|
||||||
|
envValue: "2",
|
||||||
|
expected: numDownloadParts,
|
||||||
|
description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Set or clear the environment variable
|
||||||
|
if tc.envValue != "" {
|
||||||
|
t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
var u *url.URL
|
||||||
|
if tc.url != "" {
|
||||||
|
var err error
|
||||||
|
u, err = url.Parse(tc.url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse URL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
got := getNumDownloadParts(u)
|
||||||
|
assert.Equal(t, tc.expected, got, tc.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manifest) RemoveLayers() error {
|
func (m *Manifest) RemoveLayers() error {
|
||||||
for _, layer := range append(m.Layers, m.Config) {
|
ms, err := Manifests(true)
|
||||||
if layer.Digest != "" {
|
if err != nil {
|
||||||
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
|
return err
|
||||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
}
|
||||||
} else if err != nil {
|
|
||||||
return err
|
// Build set of digests still in use by other manifests
|
||||||
|
inUse := make(map[string]struct{})
|
||||||
|
for _, other := range ms {
|
||||||
|
for _, layer := range append(other.Layers, other.Config) {
|
||||||
|
if layer.Digest != "" {
|
||||||
|
inUse[layer.Digest] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove layers not used by any other manifest
|
||||||
|
for _, layer := range append(m.Layers, m.Config) {
|
||||||
|
if layer.Digest == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, used := inUse[layer.Digest]; used {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
blob, err := GetBlobsPath(layer.Digest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
|
||||||
|
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1124,6 +1124,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
QuantizationLevel: m.Config.FileType,
|
QuantizationLevel: m.Config.FileType,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For image generation models, populate details from imagegen package
|
||||||
|
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||||
|
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||||
|
modelDetails.Family = info.Architecture
|
||||||
|
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||||
|
modelDetails.QuantizationLevel = info.Quantization
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if req.System != "" {
|
if req.System != "" {
|
||||||
m.System = req.System
|
m.System = req.System
|
||||||
}
|
}
|
||||||
@@ -1206,6 +1215,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -2059,8 +2072,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if req.Think != nil && req.Think.Bool() {
|
if req.Think != nil && req.Think.Bool() {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
// Set think to nil when being used with Anthropic API to connect to tools like claude code
|
||||||
return
|
if _, ok := c.Get("relax_thinking"); ok {
|
||||||
|
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
|
||||||
|
req.Think = nil
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
48
x/README.md
48
x/README.md
@@ -1,24 +1,50 @@
|
|||||||
# Experimental Features
|
# Experimental Features
|
||||||
|
|
||||||
## MLX Backend
|
## MLX Backend
|
||||||
|
|
||||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||||
|
|
||||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
|
Support is currently limited to macOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors.
|
||||||
|
|
||||||
```
|
### Building ollama-mlx
|
||||||
|
|
||||||
|
The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
|
||||||
|
|
||||||
|
#### macOS (Apple Silicon and Intel)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build MLX backend libraries
|
||||||
cmake --preset MLX
|
cmake --preset MLX
|
||||||
cmake --build --preset MLX --parallel
|
cmake --build --preset MLX --parallel
|
||||||
cmake --install --component MLX
|
cmake --install build --component MLX
|
||||||
go build -tags mlx .
|
|
||||||
|
# Build ollama-mlx binary
|
||||||
|
go build -tags mlx -o ollama-mlx .
|
||||||
```
|
```
|
||||||
|
|
||||||
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
|
#### Linux (CUDA)
|
||||||
|
|
||||||
|
On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build MLX backend libraries with CUDA support
|
||||||
|
cmake --preset 'MLX CUDA 13'
|
||||||
|
cmake --build --preset 'MLX CUDA 13' --parallel
|
||||||
|
cmake --install build --component MLX
|
||||||
|
|
||||||
|
# Build ollama-mlx binary
|
||||||
|
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
|
||||||
|
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
|
||||||
|
go build -tags mlx -o ollama-mlx .
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Using build scripts
|
||||||
|
|
||||||
|
The build scripts automatically create the `ollama-mlx` binary:
|
||||||
|
|
||||||
|
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
|
||||||
|
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
|
||||||
|
|
||||||
## Image Generation
|
## Image Generation
|
||||||
|
|
||||||
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
|
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
|
||||||
|
|
||||||
```
|
|
||||||
go build -o imagegen ./x/imagegen/cmd/engine
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ var optionLabels = []string{
|
|||||||
var toolDisplayNames = map[string]string{
|
var toolDisplayNames = map[string]string{
|
||||||
"bash": "Bash",
|
"bash": "Bash",
|
||||||
"web_search": "Web Search",
|
"web_search": "Web Search",
|
||||||
|
"web_fetch": "Web Fetch",
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolDisplayName returns the human-readable display name for a tool.
|
// ToolDisplayName returns the human-readable display name for a tool.
|
||||||
@@ -565,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For web fetch, show URL and internet notice
|
||||||
|
if toolName == "web_fetch" {
|
||||||
|
if url, ok := args["url"].(string); ok {
|
||||||
|
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||||
|
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||||
|
sb.WriteString("Uses internet via ollama.com")
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Generic display
|
// Generic display
|
||||||
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
@@ -1017,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if toolName == "web_fetch" {
|
||||||
|
if url, ok := args["url"].(string); ok {
|
||||||
|
// Truncate long URLs
|
||||||
|
if len(url) > 50 {
|
||||||
|
url = url[:47] + "..."
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
365
x/cmd/run.go
365
x/cmd/run.go
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@@ -24,6 +25,14 @@ import (
|
|||||||
"github.com/ollama/ollama/x/tools"
|
"github.com/ollama/ollama/x/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MultilineState identifies which kind of multiline (triple-quoted) input,
// if any, the interactive prompt is currently collecting.
type MultilineState int

const (
	// MultilineNone means no multiline input is in progress.
	MultilineNone MultilineState = iota
	// MultilineSystem means a multiline system message is being collected.
	MultilineSystem
)
|
||||||
|
|
||||||
// Tool output capping constants
|
// Tool output capping constants
|
||||||
const (
|
const (
|
||||||
// localModelTokenLimit is the token limit for local models (smaller context).
|
// localModelTokenLimit is the token limit for local models (smaller context).
|
||||||
@@ -130,6 +139,7 @@ type RunOptions struct {
|
|||||||
KeepAlive *api.Duration
|
KeepAlive *api.Duration
|
||||||
Think *api.ThinkValue
|
Think *api.ThinkValue
|
||||||
HideThinking bool
|
HideThinking bool
|
||||||
|
Verbose bool
|
||||||
|
|
||||||
// Agent fields (managed externally for session persistence)
|
// Agent fields (managed externally for session persistence)
|
||||||
Tools *tools.Registry
|
Tools *tools.Registry
|
||||||
@@ -178,6 +188,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
|||||||
var thinkTagClosed bool = false
|
var thinkTagClosed bool = false
|
||||||
var pendingToolCalls []api.ToolCall
|
var pendingToolCalls []api.ToolCall
|
||||||
var consecutiveErrors int // Track consecutive 500 errors for retry limit
|
var consecutiveErrors int // Track consecutive 500 errors for retry limit
|
||||||
|
var latest api.ChatResponse
|
||||||
|
|
||||||
role := "assistant"
|
role := "assistant"
|
||||||
messages := opts.Messages
|
messages := opts.Messages
|
||||||
@@ -187,6 +198,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
|||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
latest = response
|
||||||
role = response.Message.Role
|
role = response.Message.Role
|
||||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||||
if !thinkTagOpened {
|
if !thinkTagOpened {
|
||||||
@@ -483,6 +495,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
|||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.Verbose {
|
||||||
|
latest.Summary()
|
||||||
|
}
|
||||||
|
|
||||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -634,7 +650,8 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
|
|||||||
// GenerateInteractive runs an interactive agent session.
|
// GenerateInteractive runs an interactive agent session.
|
||||||
// This is called from cmd.go when --experimental flag is set.
|
// This is called from cmd.go when --experimental flag is set.
|
||||||
// If yoloMode is true, all tool approvals are skipped.
|
// If yoloMode is true, all tool approvals are skipped.
|
||||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
|
// If enableWebsearch is true, the web search tool is registered.
|
||||||
|
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
|
||||||
scanner, err := readline.New(readline.Prompt{
|
scanner, err := readline.New(readline.Prompt{
|
||||||
Prompt: ">>> ",
|
Prompt: ">>> ",
|
||||||
AltPrompt: "... ",
|
AltPrompt: "... ",
|
||||||
@@ -660,6 +677,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
if supportsTools {
|
if supportsTools {
|
||||||
toolRegistry = tools.DefaultRegistry()
|
toolRegistry = tools.DefaultRegistry()
|
||||||
|
|
||||||
|
// Register web search and web fetch tools if enabled via flag
|
||||||
|
if enableWebsearch {
|
||||||
|
toolRegistry.RegisterWebSearch()
|
||||||
|
toolRegistry.RegisterWebFetch()
|
||||||
|
}
|
||||||
|
|
||||||
if toolRegistry.Has("bash") {
|
if toolRegistry.Has("bash") {
|
||||||
fmt.Fprintln(os.Stderr)
|
fmt.Fprintln(os.Stderr)
|
||||||
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
|
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
|
||||||
@@ -667,6 +690,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
fmt.Fprintln(os.Stderr)
|
fmt.Fprintln(os.Stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
|
||||||
|
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
}
|
||||||
|
|
||||||
if yoloMode {
|
if yoloMode {
|
||||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
||||||
}
|
}
|
||||||
@@ -677,6 +705,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
|
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
var format string
|
||||||
|
var system string
|
||||||
|
var multiline MultilineState = MultilineNone
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := scanner.Readline()
|
line, err := scanner.Readline()
|
||||||
@@ -688,13 +719,39 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
if line == "" {
|
if line == "" {
|
||||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||||
}
|
}
|
||||||
|
scanner.Prompt.UseAlt = false
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
|
multiline = MultilineNone
|
||||||
continue
|
continue
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
case multiline != MultilineNone:
|
||||||
|
// check if there's a multiline terminating string
|
||||||
|
before, ok := strings.CutSuffix(line, `"""`)
|
||||||
|
sb.WriteString(before)
|
||||||
|
if !ok {
|
||||||
|
fmt.Fprintln(&sb)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch multiline {
|
||||||
|
case MultilineSystem:
|
||||||
|
system = sb.String()
|
||||||
|
newMessage := api.Message{Role: "system", Content: system}
|
||||||
|
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||||
|
messages[len(messages)-1] = newMessage
|
||||||
|
} else {
|
||||||
|
messages = append(messages, newMessage)
|
||||||
|
}
|
||||||
|
fmt.Println("Set system message.")
|
||||||
|
sb.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
multiline = MultilineNone
|
||||||
|
scanner.Prompt.UseAlt = false
|
||||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||||
return nil
|
return nil
|
||||||
case strings.HasPrefix(line, "/clear"):
|
case strings.HasPrefix(line, "/clear"):
|
||||||
@@ -707,6 +764,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
continue
|
continue
|
||||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||||
|
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||||
|
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||||
|
fmt.Fprintln(os.Stderr, " /load Load a different model")
|
||||||
|
fmt.Fprintln(os.Stderr, " /save Save session as a model")
|
||||||
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
|
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
|
||||||
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
|
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
|
||||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||||
@@ -716,6 +777,303 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
|
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
continue
|
continue
|
||||||
|
case strings.HasPrefix(line, "/set"):
|
||||||
|
args := strings.Fields(line)
|
||||||
|
if len(args) > 1 {
|
||||||
|
switch args[1] {
|
||||||
|
case "history":
|
||||||
|
scanner.HistoryEnable()
|
||||||
|
case "nohistory":
|
||||||
|
scanner.HistoryDisable()
|
||||||
|
case "wordwrap":
|
||||||
|
wordWrap = true
|
||||||
|
fmt.Println("Set 'wordwrap' mode.")
|
||||||
|
case "nowordwrap":
|
||||||
|
wordWrap = false
|
||||||
|
fmt.Println("Set 'nowordwrap' mode.")
|
||||||
|
case "verbose":
|
||||||
|
if err := cmd.Flags().Set("verbose", "true"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("Set 'verbose' mode.")
|
||||||
|
case "quiet":
|
||||||
|
if err := cmd.Flags().Set("verbose", "false"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("Set 'quiet' mode.")
|
||||||
|
case "think":
|
||||||
|
thinkValue := api.ThinkValue{Value: true}
|
||||||
|
var maybeLevel string
|
||||||
|
if len(args) > 2 {
|
||||||
|
maybeLevel = args[2]
|
||||||
|
}
|
||||||
|
if maybeLevel != "" {
|
||||||
|
thinkValue.Value = maybeLevel
|
||||||
|
}
|
||||||
|
think = &thinkValue
|
||||||
|
// Check if model supports thinking
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
|
||||||
|
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||||
|
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if maybeLevel != "" {
|
||||||
|
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
|
||||||
|
} else {
|
||||||
|
fmt.Println("Set 'think' mode.")
|
||||||
|
}
|
||||||
|
case "nothink":
|
||||||
|
think = &api.ThinkValue{Value: false}
|
||||||
|
// Check if model supports thinking
|
||||||
|
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||||
|
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
|
||||||
|
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||||
|
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println("Set 'nothink' mode.")
|
||||||
|
case "format":
|
||||||
|
if len(args) < 3 || args[2] != "json" {
|
||||||
|
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||||
|
} else {
|
||||||
|
format = args[2]
|
||||||
|
fmt.Printf("Set format to '%s' mode.\n", args[2])
|
||||||
|
}
|
||||||
|
case "noformat":
|
||||||
|
format = ""
|
||||||
|
fmt.Println("Disabled format.")
|
||||||
|
case "parameter":
|
||||||
|
if len(args) < 4 {
|
||||||
|
fmt.Println("Usage: /set parameter <name> <value>")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
params := args[3:]
|
||||||
|
fp, err := api.FormatParams(map[string][]string{args[2]: params})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Couldn't set parameter: %q\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
|
||||||
|
options[args[2]] = fp[args[2]]
|
||||||
|
case "system":
|
||||||
|
if len(args) < 3 {
|
||||||
|
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
multiline = MultilineSystem
|
||||||
|
|
||||||
|
line := strings.Join(args[2:], " ")
|
||||||
|
line, ok := strings.CutPrefix(line, `"""`)
|
||||||
|
if !ok {
|
||||||
|
multiline = MultilineNone
|
||||||
|
} else {
|
||||||
|
// only cut suffix if the line is multiline
|
||||||
|
line, ok = strings.CutSuffix(line, `"""`)
|
||||||
|
if ok {
|
||||||
|
multiline = MultilineNone
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(line)
|
||||||
|
if multiline != MultilineNone {
|
||||||
|
scanner.Prompt.UseAlt = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
system = sb.String()
|
||||||
|
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||||
|
// Check if the slice is not empty and the last message is from 'system'
|
||||||
|
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||||
|
// Replace the last message
|
||||||
|
messages[len(messages)-1] = newMessage
|
||||||
|
} else {
|
||||||
|
messages = append(messages, newMessage)
|
||||||
|
}
|
||||||
|
fmt.Println("Set system message.")
|
||||||
|
sb.Reset()
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
case strings.HasPrefix(line, "/show"):
|
||||||
|
args := strings.Fields(line)
|
||||||
|
if len(args) > 1 {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't connect to ollama server")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req := &api.ShowRequest{
|
||||||
|
Name: modelName,
|
||||||
|
Options: options,
|
||||||
|
}
|
||||||
|
resp, err := client.Show(cmd.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't get model")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch args[1] {
|
||||||
|
case "info":
|
||||||
|
fmt.Fprintf(os.Stderr, " Model\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
|
||||||
|
if resp.Details.Family != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
|
||||||
|
}
|
||||||
|
if resp.Details.ParameterSize != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
|
||||||
|
}
|
||||||
|
if resp.Details.QuantizationLevel != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
|
||||||
|
}
|
||||||
|
if len(resp.Capabilities) > 0 {
|
||||||
|
caps := make([]string, len(resp.Capabilities))
|
||||||
|
for i, c := range resp.Capabilities {
|
||||||
|
caps[i] = string(c)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
|
||||||
|
}
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
case "license":
|
||||||
|
if resp.License == "" {
|
||||||
|
fmt.Println("No license was specified for this model.")
|
||||||
|
} else {
|
||||||
|
fmt.Println(resp.License)
|
||||||
|
}
|
||||||
|
case "modelfile":
|
||||||
|
fmt.Println(resp.Modelfile)
|
||||||
|
case "parameters":
|
||||||
|
fmt.Println("Model defined parameters:")
|
||||||
|
if resp.Parameters == "" {
|
||||||
|
fmt.Println(" No additional parameters were specified.")
|
||||||
|
} else {
|
||||||
|
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||||
|
fmt.Printf(" %s\n", l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(options) > 0 {
|
||||||
|
fmt.Println("\nUser defined parameters:")
|
||||||
|
for k, v := range options {
|
||||||
|
fmt.Printf(" %-30s %v\n", k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "system":
|
||||||
|
switch {
|
||||||
|
case system != "":
|
||||||
|
fmt.Println(system + "\n")
|
||||||
|
case resp.System != "":
|
||||||
|
fmt.Println(resp.System + "\n")
|
||||||
|
default:
|
||||||
|
fmt.Println("No system message was specified for this model.")
|
||||||
|
}
|
||||||
|
case "template":
|
||||||
|
if resp.Template != "" {
|
||||||
|
fmt.Println(resp.Template)
|
||||||
|
} else {
|
||||||
|
fmt.Println("No prompt template was specified for this model.")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
case strings.HasPrefix(line, "/load"):
|
||||||
|
args := strings.Fields(line)
|
||||||
|
if len(args) != 2 {
|
||||||
|
fmt.Println("Usage: /load <modelname>")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newModelName := args[1]
|
||||||
|
fmt.Printf("Loading model '%s'\n", newModelName)
|
||||||
|
|
||||||
|
// Create progress spinner
|
||||||
|
p := progress.NewProgress(os.Stderr)
|
||||||
|
spinner := progress.NewSpinner("")
|
||||||
|
p.Add("", spinner)
|
||||||
|
|
||||||
|
// Get client
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
p.StopAndClear()
|
||||||
|
fmt.Println("error: couldn't connect to ollama server")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if model exists and get its info
|
||||||
|
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
|
||||||
|
if err != nil {
|
||||||
|
p.StopAndClear()
|
||||||
|
if strings.Contains(err.Error(), "not found") {
|
||||||
|
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("error: %v\n", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// For cloud models, no need to preload
|
||||||
|
if info.RemoteHost == "" {
|
||||||
|
// Preload the model by sending an empty generate request
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: newModelName,
|
||||||
|
Think: think,
|
||||||
|
}
|
||||||
|
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
p.StopAndClear()
|
||||||
|
if strings.Contains(err.Error(), "not found") {
|
||||||
|
fmt.Printf("Couldn't find model '%s'\n", newModelName)
|
||||||
|
} else if strings.Contains(err.Error(), "does not support thinking") {
|
||||||
|
fmt.Printf("error: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("error loading model: %v\n", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.StopAndClear()
|
||||||
|
modelName = newModelName
|
||||||
|
messages = []api.Message{}
|
||||||
|
approval.Reset()
|
||||||
|
continue
|
||||||
|
case strings.HasPrefix(line, "/save"):
|
||||||
|
args := strings.Fields(line)
|
||||||
|
if len(args) != 2 {
|
||||||
|
fmt.Println("Usage: /save <modelname>")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't connect to ollama server")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req := &api.CreateRequest{
|
||||||
|
Model: args[1],
|
||||||
|
From: modelName,
|
||||||
|
Parameters: options,
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
fn := func(resp api.ProgressResponse) error { return nil }
|
||||||
|
err = client.Create(cmd.Context(), req, fn)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("error: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("Created new model '%s'\n", args[1])
|
||||||
|
continue
|
||||||
case strings.HasPrefix(line, "/"):
|
case strings.HasPrefix(line, "/"):
|
||||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||||
continue
|
continue
|
||||||
@@ -723,14 +1081,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
sb.WriteString(line)
|
sb.WriteString(line)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sb.Len() > 0 {
|
if sb.Len() > 0 && multiline == MultilineNone {
|
||||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||||
messages = append(messages, newMessage)
|
messages = append(messages, newMessage)
|
||||||
|
|
||||||
|
verbose, _ := cmd.Flags().GetBool("verbose")
|
||||||
opts := RunOptions{
|
opts := RunOptions{
|
||||||
Model: modelName,
|
Model: modelName,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
WordWrap: wordWrap,
|
WordWrap: wordWrap,
|
||||||
|
Format: format,
|
||||||
Options: options,
|
Options: options,
|
||||||
Think: think,
|
Think: think,
|
||||||
HideThinking: hideThinking,
|
HideThinking: hideThinking,
|
||||||
@@ -738,6 +1098,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
|||||||
Tools: toolRegistry,
|
Tools: toolRegistry,
|
||||||
Approval: approval,
|
Approval: approval,
|
||||||
YoloMode: yoloMode,
|
YoloMode: yoloMode,
|
||||||
|
Verbose: verbose,
|
||||||
}
|
}
|
||||||
|
|
||||||
assistant, err := Chat(cmd.Context(), opts)
|
assistant, err := Chat(cmd.Context(), opts)
|
||||||
|
|||||||
@@ -234,3 +234,17 @@ ollama create z-image
|
|||||||
3. Copy config files (*.json) as config layers
|
3. Copy config files (*.json) as config layers
|
||||||
4. Write manifest
|
4. Write manifest
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## FP8 Quantization
|
||||||
|
|
||||||
|
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ./weights/Z-Image-Turbo
|
||||||
|
ollama create z-image-fp8 --quantize fp8
|
||||||
|
```
|
||||||
|
|
||||||
|
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -101,10 +99,10 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
|
|||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
|
|
||||||
var imagePath string
|
var imageBase64 string
|
||||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||||
if resp.Done {
|
if resp.Done {
|
||||||
imagePath = extractPath(resp.Content)
|
imageBase64 = extractBase64(resp.Content)
|
||||||
} else {
|
} else {
|
||||||
progress := parseProgress(resp.Content)
|
progress := parseProgress(resp.Content)
|
||||||
if progress.Total > 0 {
|
if progress.Total > 0 {
|
||||||
@@ -118,14 +116,14 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SSEvent("done", buildResponse(imagePath, format))
|
c.SSEvent("done", buildResponse(imageBase64, format))
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||||
var imagePath string
|
var imageBase64 string
|
||||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||||
if resp.Done {
|
if resp.Done {
|
||||||
imagePath = extractPath(resp.Content)
|
imageBase64 = extractBase64(resp.Content)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -133,7 +131,7 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, buildResponse(imagePath, format))
|
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSize(size string) (int32, int32) {
|
func parseSize(size string) (int32, int32) {
|
||||||
@@ -152,9 +150,9 @@ func parseSize(size string) (int32, int32) {
|
|||||||
return int32(w), int32(h)
|
return int32(w), int32(h)
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractPath(content string) string {
|
func extractBase64(content string) string {
|
||||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
if strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||||
return strings.TrimSpace(content[idx+16:])
|
return content[13:]
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -165,23 +163,21 @@ func parseProgress(content string) ImageProgressEvent {
|
|||||||
return ImageProgressEvent{Step: step, Total: total}
|
return ImageProgressEvent{Step: step, Total: total}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildResponse(imagePath, format string) ImageGenerationResponse {
|
func buildResponse(imageBase64, format string) ImageGenerationResponse {
|
||||||
resp := ImageGenerationResponse{
|
resp := ImageGenerationResponse{
|
||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
Data: make([]ImageData, 1),
|
Data: make([]ImageData, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
if imagePath == "" {
|
if imageBase64 == "" {
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
if format == "url" {
|
if format == "url" {
|
||||||
resp.Data[0].URL = "file://" + imagePath
|
// URL format not supported when using base64 transfer
|
||||||
|
resp.Data[0].B64JSON = imageBase64
|
||||||
} else {
|
} else {
|
||||||
data, err := os.ReadFile(imagePath)
|
resp.Data[0].B64JSON = imageBase64
|
||||||
if err == nil {
|
|
||||||
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|||||||
197
x/imagegen/cache/teacache.go
vendored
Normal file
197
x/imagegen/cache/teacache.go
vendored
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
// Package cache provides caching mechanisms for diffusion model inference.
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
|
||||||
|
// It caches the transformer output and reuses it when timestep values
|
||||||
|
// are similar between consecutive steps.
|
||||||
|
//
|
||||||
|
// For CFG (classifier-free guidance), it caches pos and neg predictions
|
||||||
|
// separately and always computes CFG fresh to avoid error amplification.
|
||||||
|
//
|
||||||
|
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
|
||||||
|
// https://github.com/ali-vilab/TeaCache
|
||||||
|
type TeaCache struct {
|
||||||
|
// Cached transformer output from last computed step (non-CFG mode)
|
||||||
|
cachedOutput *mlx.Array
|
||||||
|
|
||||||
|
// Cached CFG outputs (pos and neg separately)
|
||||||
|
cachedPosOutput *mlx.Array
|
||||||
|
cachedNegOutput *mlx.Array
|
||||||
|
|
||||||
|
// Previous timestep value for difference calculation
|
||||||
|
prevTimestep float32
|
||||||
|
|
||||||
|
// Accumulated difference for rescaling
|
||||||
|
accumulatedDiff float32
|
||||||
|
|
||||||
|
// Configuration
|
||||||
|
threshold float32 // Threshold for recomputation decision
|
||||||
|
rescaleFactor float32 // Model-specific rescaling factor
|
||||||
|
skipEarlySteps int // Number of early steps to never cache
|
||||||
|
|
||||||
|
// Statistics
|
||||||
|
cacheHits int
|
||||||
|
cacheMisses int
|
||||||
|
}
|
||||||
|
|
||||||
|
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
	// Threshold for recomputation. A step is recomputed once the accumulated
	// (rescaled) timestep difference exceeds this value, so a HIGHER
	// threshold gives more cache hits (with potential quality loss) and a
	// lower threshold forces more recomputation.
	// Recommended: 0.05-0.15 for image models
	Threshold float32

	// RescaleFactor is the multiplier applied to the absolute timestep
	// difference before it is accumulated.
	// Model-specific, typically 1.0-2.0
	RescaleFactor float32

	// SkipEarlySteps: number of early steps to always compute (never cache).
	// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
	SkipEarlySteps int
}
|
||||||
|
|
||||||
|
// DefaultTeaCacheConfig returns default configuration for TeaCache.
|
||||||
|
func DefaultTeaCacheConfig() *TeaCacheConfig {
|
||||||
|
return &TeaCacheConfig{
|
||||||
|
Threshold: 0.1,
|
||||||
|
RescaleFactor: 1.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTeaCache creates a new TeaCache instance.
|
||||||
|
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = DefaultTeaCacheConfig()
|
||||||
|
}
|
||||||
|
return &TeaCache{
|
||||||
|
threshold: cfg.Threshold,
|
||||||
|
rescaleFactor: cfg.RescaleFactor,
|
||||||
|
skipEarlySteps: cfg.SkipEarlySteps,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldCompute determines if we should compute the full forward pass
|
||||||
|
// or reuse the cached output based on timestep similarity.
|
||||||
|
//
|
||||||
|
// Algorithm:
|
||||||
|
// 1. First step always computes
|
||||||
|
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
|
||||||
|
// 3. If accumulated difference > threshold, compute new output
|
||||||
|
// 4. Otherwise, reuse cached output
|
||||||
|
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
|
||||||
|
// Always compute early steps (critical for structure)
|
||||||
|
// Check both regular cache and CFG cache
|
||||||
|
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
|
||||||
|
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute absolute difference between current and previous timestep
|
||||||
|
diff := timestep - tc.prevTimestep
|
||||||
|
if diff < 0 {
|
||||||
|
diff = -diff
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply rescaling factor
|
||||||
|
scaledDiff := diff * tc.rescaleFactor
|
||||||
|
|
||||||
|
// Accumulate difference (helps track drift over multiple cached steps)
|
||||||
|
tc.accumulatedDiff += scaledDiff
|
||||||
|
|
||||||
|
// Decision based on accumulated difference
|
||||||
|
if tc.accumulatedDiff > tc.threshold {
|
||||||
|
tc.accumulatedDiff = 0 // Reset accumulator
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
|
||||||
|
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
|
||||||
|
// Free previous cached output
|
||||||
|
if tc.cachedOutput != nil {
|
||||||
|
tc.cachedOutput.Free()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store new cached values
|
||||||
|
tc.cachedOutput = output
|
||||||
|
tc.prevTimestep = timestep
|
||||||
|
tc.cacheMisses++
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
|
||||||
|
// This allows CFG to be computed fresh each step, avoiding error amplification.
|
||||||
|
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
|
||||||
|
// Free previous cached outputs
|
||||||
|
if tc.cachedPosOutput != nil {
|
||||||
|
tc.cachedPosOutput.Free()
|
||||||
|
}
|
||||||
|
if tc.cachedNegOutput != nil {
|
||||||
|
tc.cachedNegOutput.Free()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store new cached values
|
||||||
|
tc.cachedPosOutput = posOutput
|
||||||
|
tc.cachedNegOutput = negOutput
|
||||||
|
tc.prevTimestep = timestep
|
||||||
|
tc.cacheMisses++
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCached returns the cached output (non-CFG mode).
|
||||||
|
func (tc *TeaCache) GetCached() *mlx.Array {
|
||||||
|
tc.cacheHits++
|
||||||
|
return tc.cachedOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCFGCached returns cached pos and neg outputs for CFG mode.
|
||||||
|
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
|
||||||
|
tc.cacheHits++
|
||||||
|
return tc.cachedPosOutput, tc.cachedNegOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasCFGCache returns true if CFG cache is available.
|
||||||
|
func (tc *TeaCache) HasCFGCache() bool {
|
||||||
|
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Arrays returns all arrays that should be kept alive.
|
||||||
|
func (tc *TeaCache) Arrays() []*mlx.Array {
|
||||||
|
var arrays []*mlx.Array
|
||||||
|
if tc.cachedOutput != nil {
|
||||||
|
arrays = append(arrays, tc.cachedOutput)
|
||||||
|
}
|
||||||
|
if tc.cachedPosOutput != nil {
|
||||||
|
arrays = append(arrays, tc.cachedPosOutput)
|
||||||
|
}
|
||||||
|
if tc.cachedNegOutput != nil {
|
||||||
|
arrays = append(arrays, tc.cachedNegOutput)
|
||||||
|
}
|
||||||
|
return arrays
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns cache hit/miss statistics.
|
||||||
|
func (tc *TeaCache) Stats() (hits, misses int) {
|
||||||
|
return tc.cacheHits, tc.cacheMisses
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free releases all cached arrays.
|
||||||
|
func (tc *TeaCache) Free() {
|
||||||
|
if tc.cachedOutput != nil {
|
||||||
|
tc.cachedOutput.Free()
|
||||||
|
tc.cachedOutput = nil
|
||||||
|
}
|
||||||
|
if tc.cachedPosOutput != nil {
|
||||||
|
tc.cachedPosOutput.Free()
|
||||||
|
tc.cachedPosOutput = nil
|
||||||
|
}
|
||||||
|
if tc.cachedNegOutput != nil {
|
||||||
|
tc.cachedNegOutput.Free()
|
||||||
|
tc.cachedNegOutput = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -44,62 +44,64 @@ func DefaultOptions() ImageGenOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Show displays information about an image generation model.
|
// ModelInfo contains metadata about an image generation model.
|
||||||
func Show(modelName string, w io.Writer) error {
|
type ModelInfo struct {
|
||||||
manifest, err := LoadManifest(modelName)
|
Architecture string
|
||||||
if err != nil {
|
ParameterCount int64
|
||||||
return fmt.Errorf("failed to load manifest: %w", err)
|
Quantization string
|
||||||
}
|
|
||||||
|
|
||||||
// Count total size
|
|
||||||
var totalSize int64
|
|
||||||
for _, layer := range manifest.Manifest.Layers {
|
|
||||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
|
||||||
totalSize += layer.Size
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read model_index.json for architecture
|
|
||||||
var architecture string
|
|
||||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
|
||||||
var index struct {
|
|
||||||
Architecture string `json:"architecture"`
|
|
||||||
}
|
|
||||||
if json.Unmarshal(data, &index) == nil {
|
|
||||||
architecture = index.Architecture
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
|
|
||||||
paramCount := totalSize / 2
|
|
||||||
paramStr := formatParamCount(paramCount)
|
|
||||||
|
|
||||||
// Print Model info
|
|
||||||
fmt.Fprintln(w, " Model")
|
|
||||||
if architecture != "" {
|
|
||||||
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
|
|
||||||
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
|
|
||||||
fmt.Fprintln(w)
|
|
||||||
|
|
||||||
// Print Capabilities
|
|
||||||
fmt.Fprintln(w, " Capabilities")
|
|
||||||
fmt.Fprintf(w, " %s\n", "image")
|
|
||||||
fmt.Fprintln(w)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatParamCount formats parameter count as human-readable string.
|
// GetModelInfo returns metadata about an image generation model.
|
||||||
func formatParamCount(count int64) string {
|
func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||||
if count >= 1_000_000_000 {
|
manifest, err := LoadManifest(modelName)
|
||||||
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||||
}
|
}
|
||||||
if count >= 1_000_000 {
|
|
||||||
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
|
info := &ModelInfo{}
|
||||||
|
|
||||||
|
// Read model_index.json for architecture, parameter count, and quantization
|
||||||
|
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||||
|
var index struct {
|
||||||
|
Architecture string `json:"architecture"`
|
||||||
|
ParameterCount int64 `json:"parameter_count"`
|
||||||
|
Quantization string `json:"quantization"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(data, &index) == nil {
|
||||||
|
info.Architecture = index.Architecture
|
||||||
|
info.ParameterCount = index.ParameterCount
|
||||||
|
info.Quantization = index.Quantization
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%d", count)
|
|
||||||
|
// Fallback: detect quantization from tensor names if not in config
|
||||||
|
if info.Quantization == "" {
|
||||||
|
for _, layer := range manifest.Manifest.Layers {
|
||||||
|
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||||
|
info.Quantization = "FP8"
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info.Quantization == "" {
|
||||||
|
info.Quantization = "BF16"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: estimate parameter count if not in config
|
||||||
|
if info.ParameterCount == 0 {
|
||||||
|
var totalSize int64
|
||||||
|
for _, layer := range manifest.Manifest.Layers {
|
||||||
|
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||||
|
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
|
||||||
|
totalSize += layer.Size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Assume BF16 (2 bytes/param) as rough estimate
|
||||||
|
info.ParameterCount = totalSize / 2
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterFlags adds image generation flags to the given command.
|
// RegisterFlags adds image generation flags to the given command.
|
||||||
@@ -121,11 +123,6 @@ func RegisterFlags(cmd *cobra.Command) {
|
|||||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||||
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
|
||||||
// Verify it's a valid image gen model
|
|
||||||
if ResolveModelName(name) == "" {
|
|
||||||
return fmt.Errorf("unknown image generation model: %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get options from flags (with env var defaults)
|
// Get options from flags (with env var defaults)
|
||||||
opts := DefaultOptions()
|
opts := DefaultOptions()
|
||||||
if cmd != nil && cmd.Flags() != nil {
|
if cmd != nil && cmd.Flags() != nil {
|
||||||
@@ -183,8 +180,7 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
|||||||
p.Add("", spinner)
|
p.Add("", spinner)
|
||||||
|
|
||||||
var stepBar *progress.StepBar
|
var stepBar *progress.StepBar
|
||||||
var imagePath string
|
var imageBase64 string
|
||||||
|
|
||||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||||
content := resp.Response
|
content := resp.Response
|
||||||
|
|
||||||
@@ -203,11 +199,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle final response with image path
|
// Handle final response with base64 image data
|
||||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
imageBase64 = content[13:]
|
||||||
imagePath = strings.TrimSpace(content[idx+16:])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -218,9 +212,27 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if imagePath != "" {
|
if imageBase64 != "" {
|
||||||
displayImageInTerminal(imagePath)
|
// Decode base64 and save to CWD
|
||||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to decode image: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create filename from prompt
|
||||||
|
safeName := sanitizeFilename(prompt)
|
||||||
|
if len(safeName) > 50 {
|
||||||
|
safeName = safeName[:50]
|
||||||
|
}
|
||||||
|
timestamp := time.Now().Format("20060102-150405")
|
||||||
|
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||||
|
|
||||||
|
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||||
|
return fmt.Errorf("failed to save image: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
displayImageInTerminal(filename)
|
||||||
|
fmt.Printf("Image saved to: %s\n", filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -306,7 +318,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
|||||||
p.Add("", spinner)
|
p.Add("", spinner)
|
||||||
|
|
||||||
var stepBar *progress.StepBar
|
var stepBar *progress.StepBar
|
||||||
var imagePath string
|
var imageBase64 string
|
||||||
|
|
||||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||||
content := resp.Response
|
content := resp.Response
|
||||||
@@ -326,11 +338,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle final response with image path
|
// Handle final response with base64 image data
|
||||||
if resp.Done && strings.Contains(content, "Image saved to:") {
|
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||||
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
|
imageBase64 = content[13:]
|
||||||
imagePath = strings.TrimSpace(content[idx+16:])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -342,25 +352,30 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy image to current directory with descriptive name
|
// Save image to current directory with descriptive name
|
||||||
if imagePath != "" {
|
if imageBase64 != "" {
|
||||||
|
// Decode base64 image data
|
||||||
|
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Create filename from prompt (sanitized)
|
// Create filename from prompt (sanitized)
|
||||||
safeName := sanitizeFilename(line)
|
safeName := sanitizeFilename(line)
|
||||||
if len(safeName) > 50 {
|
if len(safeName) > 50 {
|
||||||
safeName = safeName[:50]
|
safeName = safeName[:50]
|
||||||
}
|
}
|
||||||
timestamp := time.Now().Format("20060102-150405")
|
timestamp := time.Now().Format("20060102-150405")
|
||||||
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
|
||||||
|
|
||||||
// Copy file to CWD
|
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
|
||||||
if err := copyFile(imagePath, newName); err != nil {
|
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
|
||||||
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
|
continue
|
||||||
displayImageInTerminal(imagePath)
|
|
||||||
fmt.Printf("Image saved to: %s\n", imagePath)
|
|
||||||
} else {
|
|
||||||
displayImageInTerminal(newName)
|
|
||||||
fmt.Printf("Image saved to: %s\n", newName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
displayImageInTerminal(filename)
|
||||||
|
fmt.Printf("Image saved to: %s\n", filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
@@ -381,24 +396,6 @@ func sanitizeFilename(s string) string {
|
|||||||
return result.String()
|
return result.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// copyFile copies a file from src to dst.
|
|
||||||
func copyFile(src, dst string) error {
|
|
||||||
sourceFile, err := os.Open(src)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer sourceFile.Close()
|
|
||||||
|
|
||||||
destFile, err := os.Create(dst)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer destFile.Close()
|
|
||||||
|
|
||||||
_, err = io.Copy(destFile, sourceFile)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// printInteractiveHelp prints help for interactive mode commands.
|
// printInteractiveHelp prints help for interactive mode commands.
|
||||||
func printInteractiveHelp(opts ImageGenOptions) {
|
func printInteractiveHelp(opts ImageGenOptions) {
|
||||||
fmt.Fprintln(os.Stderr, "Commands:")
|
fmt.Fprintln(os.Stderr, "Commands:")
|
||||||
@@ -509,10 +506,7 @@ func displayImageInTerminal(imagePath string) bool {
|
|||||||
// Send in chunks for large images
|
// Send in chunks for large images
|
||||||
const chunkSize = 4096
|
const chunkSize = 4096
|
||||||
for i := 0; i < len(encoded); i += chunkSize {
|
for i := 0; i < len(encoded); i += chunkSize {
|
||||||
end := i + chunkSize
|
end := min(i+chunkSize, len(encoded))
|
||||||
if end > len(encoded) {
|
|
||||||
end = len(encoded)
|
|
||||||
}
|
|
||||||
chunk := encoded[i:end]
|
chunk := encoded[i:end]
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
|
|||||||
@@ -29,9 +29,10 @@ const MinOllamaVersion = "0.14.0"
|
|||||||
|
|
||||||
// CreateModel imports a tensor-based model from a local directory.
|
// CreateModel imports a tensor-based model from a local directory.
|
||||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||||
|
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
|
||||||
//
|
//
|
||||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||||
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
|
||||||
if !imagegen.IsTensorModelDir(modelDir) {
|
if !imagegen.IsTensorModelDir(modelDir) {
|
||||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||||
}
|
}
|
||||||
@@ -58,18 +59,77 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
|||||||
|
|
||||||
// Create tensor layer callback for individual tensors
|
// Create tensor layer callback for individual tensors
|
||||||
// name is path-style: "component/tensor_name"
|
// name is path-style: "component/tensor_name"
|
||||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
|
// When quantize is true, returns multiple layers (weight + scales)
|
||||||
|
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
|
||||||
|
if doQuantize {
|
||||||
|
// Check if quantization is supported
|
||||||
|
if !QuantizeSupported() {
|
||||||
|
return nil, fmt.Errorf("quantization requires MLX support")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantize the tensor (affine mode returns weight, scales, qbiases)
|
||||||
|
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create layer for quantized weight
|
||||||
|
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create layer for scales (use _scale suffix convention)
|
||||||
|
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
layers := []imagegen.LayerInfo{
|
||||||
|
{
|
||||||
|
Digest: weightLayer.Digest,
|
||||||
|
Size: weightLayer.Size,
|
||||||
|
MediaType: weightLayer.MediaType,
|
||||||
|
Name: name, // Keep original name for weight
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Digest: scalesLayer.Digest,
|
||||||
|
Size: scalesLayer.Size,
|
||||||
|
MediaType: scalesLayer.MediaType,
|
||||||
|
Name: name + "_scale", // Add _scale suffix
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add qbiases layer if present (affine mode)
|
||||||
|
if qbiasData != nil {
|
||||||
|
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
layers = append(layers, imagegen.LayerInfo{
|
||||||
|
Digest: qbiasLayer.Digest,
|
||||||
|
Size: qbiasLayer.Size,
|
||||||
|
MediaType: qbiasLayer.MediaType,
|
||||||
|
Name: name + "_qbias", // Add _qbias suffix
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return layers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-quantized path: just create a single layer
|
||||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return imagegen.LayerInfo{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
layer.Name = name
|
|
||||||
|
|
||||||
return imagegen.LayerInfo{
|
return []imagegen.LayerInfo{
|
||||||
Digest: layer.Digest,
|
{
|
||||||
Size: layer.Size,
|
Digest: layer.Digest,
|
||||||
MediaType: layer.MediaType,
|
Size: layer.Size,
|
||||||
Name: name,
|
MediaType: layer.MediaType,
|
||||||
|
Name: name,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +179,7 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
|
|||||||
p.Add("imagegen", spinner)
|
p.Add("imagegen", spinner)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
|
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
120
x/imagegen/client/quantize.go
Normal file
120
x/imagegen/client/quantize.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
|
||||||
|
// and returns safetensors data for the quantized weights, scales, and biases.
|
||||||
|
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||||
|
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||||
|
tmpDir := ensureTempDir()
|
||||||
|
|
||||||
|
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
|
||||||
|
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
|
||||||
|
}
|
||||||
|
tmpPath := tmpFile.Name()
|
||||||
|
defer os.Remove(tmpPath)
|
||||||
|
|
||||||
|
if _, err := io.Copy(tmpFile, r); err != nil {
|
||||||
|
tmpFile.Close()
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
|
||||||
|
}
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
// Load the tensor using MLX's native loader
|
||||||
|
st, err := mlx.LoadSafetensorsNative(tmpPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
|
||||||
|
}
|
||||||
|
defer st.Free()
|
||||||
|
|
||||||
|
// Get the tensor (it's stored as "data" in our minimal safetensors format)
|
||||||
|
arr := st.Get("data")
|
||||||
|
if arr == nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to BFloat16 if needed (quantize expects float type)
|
||||||
|
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
|
||||||
|
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
|
||||||
|
mlx.Eval(arr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantize with affine mode: group_size=32, bits=8
|
||||||
|
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
|
||||||
|
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
|
||||||
|
|
||||||
|
// Eval and make contiguous for data access
|
||||||
|
qweight = mlx.Contiguous(qweight)
|
||||||
|
scales = mlx.Contiguous(scales)
|
||||||
|
if qbiases != nil {
|
||||||
|
qbiases = mlx.Contiguous(qbiases)
|
||||||
|
mlx.Eval(qweight, scales, qbiases)
|
||||||
|
} else {
|
||||||
|
mlx.Eval(qweight, scales)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get shapes
|
||||||
|
qweightShape = qweight.Shape()
|
||||||
|
scalesShape = scales.Shape()
|
||||||
|
|
||||||
|
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
|
||||||
|
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
|
||||||
|
defer os.Remove(qweightPath)
|
||||||
|
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
|
||||||
|
}
|
||||||
|
qweightData, err = os.ReadFile(qweightPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save scales using MLX's native safetensors
|
||||||
|
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
|
||||||
|
defer os.Remove(scalesPath)
|
||||||
|
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
|
||||||
|
}
|
||||||
|
scalesData, err = os.ReadFile(scalesPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Affine mode returns qbiases for zero-point offset
|
||||||
|
if qbiases != nil {
|
||||||
|
qbiasShape = qbiases.Shape()
|
||||||
|
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
|
||||||
|
defer os.Remove(qbiasPath)
|
||||||
|
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
|
||||||
|
}
|
||||||
|
qbiasData, err = os.ReadFile(qbiasPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuantizeSupported returns true if quantization is supported (MLX build)
|
||||||
|
func QuantizeSupported() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureTempDir creates the temp directory for quantization if it doesn't exist
|
||||||
|
func ensureTempDir() string {
|
||||||
|
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
|
||||||
|
os.MkdirAll(tmpDir, 0755)
|
||||||
|
return tmpDir
|
||||||
|
}
|
||||||
18
x/imagegen/client/quantize_stub.go
Normal file
18
x/imagegen/client/quantize_stub.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
//go:build !mlx
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// quantizeTensor is not available without MLX
|
||||||
|
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||||
|
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuantizeSupported returns false when MLX is not available
|
||||||
|
func QuantizeSupported() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -67,6 +67,9 @@ func main() {
|
|||||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||||
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
||||||
|
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
|
||||||
|
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
|
||||||
|
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
@@ -99,13 +102,17 @@ func main() {
|
|||||||
}
|
}
|
||||||
var img *mlx.Array
|
var img *mlx.Array
|
||||||
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
|
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
|
||||||
Prompt: *prompt,
|
Prompt: *prompt,
|
||||||
Width: int32(*width),
|
NegativePrompt: *negativePrompt,
|
||||||
Height: int32(*height),
|
CFGScale: float32(*cfgScale),
|
||||||
Steps: *steps,
|
Width: int32(*width),
|
||||||
Seed: *seed,
|
Height: int32(*height),
|
||||||
CapturePath: *gpuCapture,
|
Steps: *steps,
|
||||||
LayerCache: *layerCache,
|
Seed: *seed,
|
||||||
|
CapturePath: *gpuCapture,
|
||||||
|
TeaCache: *teaCache,
|
||||||
|
TeaCacheThreshold: float32(*teaCacheThreshold),
|
||||||
|
FusedQKV: *fusedQKV,
|
||||||
})
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = saveImageArray(img, *out)
|
err = saveImageArray(img, *out)
|
||||||
|
|||||||
@@ -40,10 +40,12 @@ type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo)
|
|||||||
|
|
||||||
// CreateModel imports an image generation model from a directory.
|
// CreateModel imports an image generation model from a directory.
|
||||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||||
|
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
|
||||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||||
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||||
var layers []LayerInfo
|
var layers []LayerInfo
|
||||||
var configLayer LayerInfo
|
var configLayer LayerInfo
|
||||||
|
var totalParams int64 // Count parameters from original tensor shapes
|
||||||
|
|
||||||
// Components to process - extract individual tensors from each
|
// Components to process - extract individual tensors from each
|
||||||
components := []string{"text_encoder", "transformer", "vae"}
|
components := []string{"text_encoder", "transformer", "vae"}
|
||||||
@@ -74,7 +76,11 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
|||||||
}
|
}
|
||||||
|
|
||||||
tensorNames := extractor.ListTensors()
|
tensorNames := extractor.ListTensors()
|
||||||
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
|
quantizeMsg := ""
|
||||||
|
if quantize == "fp8" && component != "vae" {
|
||||||
|
quantizeMsg = ", quantizing to fp8"
|
||||||
|
}
|
||||||
|
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
|
||||||
|
|
||||||
for _, tensorName := range tensorNames {
|
for _, tensorName := range tensorNames {
|
||||||
td, err := extractor.GetTensor(tensorName)
|
td, err := extractor.GetTensor(tensorName)
|
||||||
@@ -83,16 +89,30 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
|||||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Count parameters from original tensor shape
|
||||||
|
if len(td.Shape) > 0 {
|
||||||
|
numElements := int64(1)
|
||||||
|
for _, dim := range td.Shape {
|
||||||
|
numElements *= int64(dim)
|
||||||
|
}
|
||||||
|
totalParams += numElements
|
||||||
|
}
|
||||||
|
|
||||||
// Store as minimal safetensors format (88 bytes header overhead)
|
// Store as minimal safetensors format (88 bytes header overhead)
|
||||||
// This enables native mmap loading via mlx_load_safetensors
|
// This enables native mmap loading via mlx_load_safetensors
|
||||||
// Use path-style name: "component/tensor_name"
|
// Use path-style name: "component/tensor_name"
|
||||||
fullName := component + "/" + tensorName
|
fullName := component + "/" + tensorName
|
||||||
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
|
|
||||||
|
// Determine if this tensor should be quantized
|
||||||
|
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
|
||||||
|
|
||||||
|
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||||
|
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
extractor.Close()
|
extractor.Close()
|
||||||
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
||||||
}
|
}
|
||||||
layers = append(layers, layer)
|
layers = append(layers, newLayers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
extractor.Close()
|
extractor.Close()
|
||||||
@@ -122,7 +142,7 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
|||||||
|
|
||||||
var r io.Reader
|
var r io.Reader
|
||||||
|
|
||||||
// For model_index.json, normalize to Ollama format
|
// For model_index.json, normalize to Ollama format and add metadata
|
||||||
if cfgPath == "model_index.json" {
|
if cfgPath == "model_index.json" {
|
||||||
data, err := os.ReadFile(fullPath)
|
data, err := os.ReadFile(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -141,6 +161,16 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
|
|||||||
}
|
}
|
||||||
delete(cfg, "_diffusers_version")
|
delete(cfg, "_diffusers_version")
|
||||||
|
|
||||||
|
// Add parameter count (counted from tensor shapes during import)
|
||||||
|
cfg["parameter_count"] = totalParams
|
||||||
|
|
||||||
|
// Add quantization info
|
||||||
|
if quantize == "fp8" {
|
||||||
|
cfg["quantization"] = "FP8"
|
||||||
|
} else {
|
||||||
|
cfg["quantization"] = "BF16"
|
||||||
|
}
|
||||||
|
|
||||||
data, err = json.MarshalIndent(cfg, "", " ")
|
data, err = json.MarshalIndent(cfg, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
||||||
|
|||||||
@@ -60,9 +60,12 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Transform to [H, W, C] for image conversion
|
// Transform to [H, W, C] for image conversion
|
||||||
img := mlx.Squeeze(arr, 0)
|
// Free intermediate arrays to avoid memory leak
|
||||||
img = mlx.Transpose(img, 1, 2, 0)
|
squeezed := mlx.Squeeze(arr, 0)
|
||||||
img = mlx.Contiguous(img)
|
transposed := mlx.Transpose(squeezed, 1, 2, 0)
|
||||||
|
squeezed.Free()
|
||||||
|
img := mlx.Contiguous(transposed)
|
||||||
|
transposed.Free()
|
||||||
mlx.Eval(img)
|
mlx.Eval(img)
|
||||||
|
|
||||||
imgShape := img.Shape()
|
imgShape := img.Shape()
|
||||||
|
|||||||
@@ -607,6 +607,11 @@ func (a *Array) Valid() bool {
|
|||||||
return a != nil && a.c.ctx != nil
|
return a != nil && a.c.ctx != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Kept returns true if the array is marked to survive Eval() cleanup.
|
||||||
|
func (a *Array) Kept() bool {
|
||||||
|
return a != nil && a.kept
|
||||||
|
}
|
||||||
|
|
||||||
func int32ToCInt(s []int32) *C.int {
|
func int32ToCInt(s []int32) *C.int {
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -1480,6 +1485,44 @@ func (a *Array) ItemInt32() int32 {
|
|||||||
return int32(val)
|
return int32(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bytes copies the raw bytes out of the array without type conversion.
|
||||||
|
// Works with common dtypes (float32, int32, uint32, uint8).
|
||||||
|
// For non-contiguous arrays, call Contiguous() first.
|
||||||
|
// Note: Triggers cleanup of non-kept arrays.
|
||||||
|
func (a *Array) Bytes() []byte {
|
||||||
|
cleanup()
|
||||||
|
nbytes := a.Nbytes()
|
||||||
|
if nbytes == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get raw pointer based on dtype
|
||||||
|
var ptr unsafe.Pointer
|
||||||
|
switch a.Dtype() {
|
||||||
|
case DtypeFloat32:
|
||||||
|
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
|
||||||
|
case DtypeInt32:
|
||||||
|
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
|
||||||
|
case DtypeUint32:
|
||||||
|
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
|
||||||
|
case DtypeUint8:
|
||||||
|
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
|
||||||
|
default:
|
||||||
|
// For other types (bf16, f16, etc), convert to float32
|
||||||
|
arr := AsType(a, DtypeFloat32)
|
||||||
|
arr.Eval()
|
||||||
|
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
|
||||||
|
nbytes = arr.Nbytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
if ptr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data := make([]byte, nbytes)
|
||||||
|
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// ============ Utility ============
|
// ============ Utility ============
|
||||||
|
|
||||||
// String returns a string representation
|
// String returns a string representation
|
||||||
@@ -1658,6 +1701,34 @@ func (s *SafetensorsFile) Free() {
|
|||||||
C.mlx_map_string_to_string_free(s.metadata)
|
C.mlx_map_string_to_string_free(s.metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
|
||||||
|
// This correctly handles all dtypes including uint32 for quantized weights.
|
||||||
|
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
||||||
|
cPath := C.CString(path)
|
||||||
|
defer C.free(unsafe.Pointer(cPath))
|
||||||
|
|
||||||
|
// Create the map
|
||||||
|
cArrays := C.mlx_map_string_to_array_new()
|
||||||
|
defer C.mlx_map_string_to_array_free(cArrays)
|
||||||
|
|
||||||
|
// Add each array to the map
|
||||||
|
for name, arr := range arrays {
|
||||||
|
cName := C.CString(name)
|
||||||
|
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
|
||||||
|
C.free(unsafe.Pointer(cName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty metadata (optional)
|
||||||
|
cMeta := C.mlx_map_string_to_string_new()
|
||||||
|
defer C.mlx_map_string_to_string_free(cMeta)
|
||||||
|
|
||||||
|
// Save
|
||||||
|
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
|
||||||
|
return fmt.Errorf("failed to save safetensors: %s", path)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ============ NPY Loading ============
|
// ============ NPY Loading ============
|
||||||
|
|
||||||
// LoadNpy loads a numpy array from an npy file
|
// LoadNpy loads a numpy array from an npy file
|
||||||
@@ -1986,7 +2057,8 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
|
|||||||
// Returns (quantized_weights, scales, biases).
|
// Returns (quantized_weights, scales, biases).
|
||||||
// groupSize: number of elements quantized together (default 64)
|
// groupSize: number of elements quantized together (default 64)
|
||||||
// bits: bits per element, 2, 4, or 8 (default 4)
|
// bits: bits per element, 2, 4, or 8 (default 4)
|
||||||
// mode: "affine" (default) or "mxfp4"
|
// mode: "affine" (default), "mxfp4", or "mxfp8"
|
||||||
|
// Note: mxfp8 mode returns nil biases (only weights and scales)
|
||||||
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
|
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
|
||||||
cMode := C.CString(mode)
|
cMode := C.CString(mode)
|
||||||
defer C.free(unsafe.Pointer(cMode))
|
defer C.free(unsafe.Pointer(cMode))
|
||||||
@@ -1995,14 +2067,21 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
res := C.mlx_vector_array_new()
|
res := C.mlx_vector_array_new()
|
||||||
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
||||||
|
|
||||||
// Result is a vector of 3 arrays: [weights, scales, biases]
|
// Result is a vector of arrays: [weights, scales, biases?]
|
||||||
|
// mxfp8 mode returns only 2 elements (no biases)
|
||||||
|
vecSize := int(C.mlx_vector_array_size(res))
|
||||||
var w0, w1, w2 C.mlx_array
|
var w0, w1, w2 C.mlx_array
|
||||||
C.mlx_vector_array_get(&w0, res, 0)
|
C.mlx_vector_array_get(&w0, res, 0)
|
||||||
C.mlx_vector_array_get(&w1, res, 1)
|
C.mlx_vector_array_get(&w1, res, 1)
|
||||||
C.mlx_vector_array_get(&w2, res, 2)
|
if vecSize >= 3 {
|
||||||
|
C.mlx_vector_array_get(&w2, res, 2)
|
||||||
|
}
|
||||||
C.mlx_vector_array_free(res)
|
C.mlx_vector_array_free(res)
|
||||||
|
|
||||||
return newArray(w0), newArray(w1), newArray(w2)
|
if vecSize >= 3 {
|
||||||
|
return newArray(w0), newArray(w1), newArray(w2)
|
||||||
|
}
|
||||||
|
return newArray(w0), newArray(w1), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dequantize reconstructs weights from quantized form.
|
// Dequantize reconstructs weights from quantized form.
|
||||||
|
|||||||
@@ -222,6 +222,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
|||||||
mlx.Keep(posEmb, negEmb)
|
mlx.Keep(posEmb, negEmb)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||||
|
var batchedEmb *mlx.Array
|
||||||
|
if useCFG {
|
||||||
|
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||||
|
mlx.Keep(batchedEmb)
|
||||||
|
mlx.Eval(batchedEmb)
|
||||||
|
}
|
||||||
|
|
||||||
// Scheduler
|
// Scheduler
|
||||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||||
@@ -264,10 +272,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
|||||||
|
|
||||||
var output *mlx.Array
|
var output *mlx.Array
|
||||||
if useCFG {
|
if useCFG {
|
||||||
// True CFG: run twice and combine with norm rescaling
|
// CFG Batching: single forward pass with batch=2
|
||||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||||
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||||
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||||
|
|
||||||
|
// Single batched forward pass
|
||||||
|
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||||
|
|
||||||
|
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||||
|
L := batchedOutput.Shape()[1]
|
||||||
|
D := batchedOutput.Shape()[2]
|
||||||
|
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||||
|
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||||
|
|
||||||
diff := mlx.Sub(posOutput, negOutput)
|
diff := mlx.Sub(posOutput, negOutput)
|
||||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||||
@@ -305,6 +322,9 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
|||||||
if negEmb != nil {
|
if negEmb != nil {
|
||||||
negEmb.Free()
|
negEmb.Free()
|
||||||
}
|
}
|
||||||
|
if batchedEmb != nil {
|
||||||
|
batchedEmb.Free()
|
||||||
|
}
|
||||||
ropeCache.ImgFreqs.Free()
|
ropeCache.ImgFreqs.Free()
|
||||||
ropeCache.TxtFreqs.Free()
|
ropeCache.TxtFreqs.Free()
|
||||||
if stepCache != nil {
|
if stepCache != nil {
|
||||||
|
|||||||
@@ -241,6 +241,14 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
mlx.Eval(posEmb, negEmb)
|
mlx.Eval(posEmb, negEmb)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre-compute batched embeddings for CFG (single forward pass optimization)
|
||||||
|
var batchedEmb *mlx.Array
|
||||||
|
if useCFG {
|
||||||
|
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||||
|
mlx.Keep(batchedEmb)
|
||||||
|
mlx.Eval(batchedEmb)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode all input images to latents and concatenate
|
// Encode all input images to latents and concatenate
|
||||||
fmt.Println("Encoding images to latents...")
|
fmt.Println("Encoding images to latents...")
|
||||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||||
@@ -291,11 +299,18 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
|
|
||||||
var output *mlx.Array
|
var output *mlx.Array
|
||||||
if useCFG {
|
if useCFG {
|
||||||
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
// CFG Batching: single forward pass with batch=2
|
||||||
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
// Tile inputs: [1, L, D] -> [2, L, D]
|
||||||
|
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
|
||||||
|
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||||
|
|
||||||
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
|
// Single batched forward pass
|
||||||
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
|
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||||
|
|
||||||
|
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||||
|
D := batchedOutput.Shape()[2]
|
||||||
|
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
|
||||||
|
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
|
||||||
|
|
||||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||||
} else {
|
} else {
|
||||||
@@ -317,6 +332,9 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
if negEmb != nil {
|
if negEmb != nil {
|
||||||
negEmb.Free()
|
negEmb.Free()
|
||||||
}
|
}
|
||||||
|
if batchedEmb != nil {
|
||||||
|
batchedEmb.Free()
|
||||||
|
}
|
||||||
ropeCache.ImgFreqs.Free()
|
ropeCache.ImgFreqs.Free()
|
||||||
ropeCache.TxtFreqs.Free()
|
ropeCache.TxtFreqs.Free()
|
||||||
imageLatentsPacked.Free()
|
imageLatentsPacked.Free()
|
||||||
|
|||||||
@@ -28,12 +28,12 @@ type Qwen3Config struct {
|
|||||||
|
|
||||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||||
type Qwen3Attention struct {
|
type Qwen3Attention struct {
|
||||||
QProj *nn.Linear `weight:"q_proj"`
|
QProj nn.LinearLayer `weight:"q_proj"`
|
||||||
KProj *nn.Linear `weight:"k_proj"`
|
KProj nn.LinearLayer `weight:"k_proj"`
|
||||||
VProj *nn.Linear `weight:"v_proj"`
|
VProj nn.LinearLayer `weight:"v_proj"`
|
||||||
OProj *nn.Linear `weight:"o_proj"`
|
OProj nn.LinearLayer `weight:"o_proj"`
|
||||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||||
// Computed fields
|
// Computed fields
|
||||||
NHeads int32
|
NHeads int32
|
||||||
NKVHeads int32
|
NKVHeads int32
|
||||||
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
|||||||
|
|
||||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||||
type Qwen3MLP struct {
|
type Qwen3MLP struct {
|
||||||
GateProj *nn.Linear `weight:"gate_proj"`
|
GateProj nn.LinearLayer `weight:"gate_proj"`
|
||||||
UpProj *nn.Linear `weight:"up_proj"`
|
UpProj nn.LinearLayer `weight:"up_proj"`
|
||||||
DownProj *nn.Linear `weight:"down_proj"`
|
DownProj nn.LinearLayer `weight:"down_proj"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward applies the MLP
|
// Forward applies the MLP
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ type TransformerConfig struct {
|
|||||||
// TimestepEmbedder creates sinusoidal timestep embeddings
|
// TimestepEmbedder creates sinusoidal timestep embeddings
|
||||||
// Output dimension is 256 (fixed), used for AdaLN modulation
|
// Output dimension is 256 (fixed), used for AdaLN modulation
|
||||||
type TimestepEmbedder struct {
|
type TimestepEmbedder struct {
|
||||||
Linear1 *nn.Linear `weight:"mlp.0"`
|
Linear1 nn.LinearLayer `weight:"mlp.0"`
|
||||||
Linear2 *nn.Linear `weight:"mlp.2"`
|
Linear2 nn.LinearLayer `weight:"mlp.2"`
|
||||||
FreqEmbedSize int32 // 256 (computed)
|
FreqEmbedSize int32 // 256 (computed)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
|||||||
|
|
||||||
// XEmbedder embeds image patches to model dimension
|
// XEmbedder embeds image patches to model dimension
|
||||||
type XEmbedder struct {
|
type XEmbedder struct {
|
||||||
Linear *nn.Linear `weight:"2-1"`
|
Linear nn.LinearLayer `weight:"2-1"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward embeds patchified image latents
|
// Forward embeds patchified image latents
|
||||||
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
|||||||
// CapEmbedder projects caption features to model dimension
|
// CapEmbedder projects caption features to model dimension
|
||||||
type CapEmbedder struct {
|
type CapEmbedder struct {
|
||||||
Norm *nn.RMSNorm `weight:"0"`
|
Norm *nn.RMSNorm `weight:"0"`
|
||||||
Linear *nn.Linear `weight:"1"`
|
Linear nn.LinearLayer `weight:"1"`
|
||||||
PadToken *mlx.Array // loaded separately at root level
|
PadToken *mlx.Array // loaded separately at root level
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,12 +100,13 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
|
|||||||
|
|
||||||
// FeedForward implements SwiGLU FFN
|
// FeedForward implements SwiGLU FFN
|
||||||
type FeedForward struct {
|
type FeedForward struct {
|
||||||
W1 *nn.Linear `weight:"w1"` // gate projection
|
W1 nn.LinearLayer `weight:"w1"` // gate projection
|
||||||
W2 *nn.Linear `weight:"w2"` // down projection
|
W2 nn.LinearLayer `weight:"w2"` // down projection
|
||||||
W3 *nn.Linear `weight:"w3"` // up projection
|
W3 nn.LinearLayer `weight:"w3"` // up projection
|
||||||
OutDim int32 // computed from W2
|
OutDim int32 // computed from W2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
||||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||||
shape := x.Shape()
|
shape := x.Shape()
|
||||||
@@ -115,6 +116,7 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
|||||||
|
|
||||||
// Reshape for matmul
|
// Reshape for matmul
|
||||||
x = mlx.Reshape(x, B*L, D)
|
x = mlx.Reshape(x, B*L, D)
|
||||||
|
|
||||||
gate := ff.W1.Forward(x)
|
gate := ff.W1.Forward(x)
|
||||||
gate = mlx.SiLU(gate)
|
gate = mlx.SiLU(gate)
|
||||||
up := ff.W3.Forward(x)
|
up := ff.W3.Forward(x)
|
||||||
@@ -126,17 +128,69 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
|||||||
|
|
||||||
// Attention implements multi-head attention with QK norm
|
// Attention implements multi-head attention with QK norm
|
||||||
type Attention struct {
|
type Attention struct {
|
||||||
ToQ *nn.Linear `weight:"to_q"`
|
ToQ nn.LinearLayer `weight:"to_q"`
|
||||||
ToK *nn.Linear `weight:"to_k"`
|
ToK nn.LinearLayer `weight:"to_k"`
|
||||||
ToV *nn.Linear `weight:"to_v"`
|
ToV nn.LinearLayer `weight:"to_v"`
|
||||||
ToOut *nn.Linear `weight:"to_out.0"`
|
ToOut nn.LinearLayer `weight:"to_out.0"`
|
||||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||||
// Computed fields
|
// Fused QKV (computed at init time for efficiency, not loaded from weights)
|
||||||
NHeads int32
|
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
|
||||||
HeadDim int32
|
Fused bool `weight:"-"` // Whether to use fused QKV path
|
||||||
Dim int32
|
// Computed fields (not loaded from weights)
|
||||||
Scale float32
|
NHeads int32 `weight:"-"`
|
||||||
|
HeadDim int32 `weight:"-"`
|
||||||
|
Dim int32 `weight:"-"`
|
||||||
|
Scale float32 `weight:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FuseQKV creates a fused QKV projection by concatenating weights.
|
||||||
|
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
|
||||||
|
// Note: Fusion is skipped for quantized weights as it would require complex
|
||||||
|
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
|
||||||
|
// the ~5% fusion benefit.
|
||||||
|
func (attn *Attention) FuseQKV() {
|
||||||
|
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip fusion for quantized weights - type assert to check
|
||||||
|
toQ, qOk := attn.ToQ.(*nn.Linear)
|
||||||
|
toK, kOk := attn.ToK.(*nn.Linear)
|
||||||
|
toV, vOk := attn.ToV.(*nn.Linear)
|
||||||
|
if !qOk || !kOk || !vOk {
|
||||||
|
// One or more are QuantizedLinear, skip fusion
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
|
||||||
|
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
|
||||||
|
qWeight := toQ.Weight
|
||||||
|
kWeight := toK.Weight
|
||||||
|
vWeight := toV.Weight
|
||||||
|
|
||||||
|
// Concatenate along output dimension (axis 0)
|
||||||
|
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
|
||||||
|
|
||||||
|
// Evaluate fused weight to ensure it's materialized
|
||||||
|
mlx.Eval(fusedWeight)
|
||||||
|
|
||||||
|
// Create fused linear layer
|
||||||
|
fusedLinear := &nn.Linear{Weight: fusedWeight}
|
||||||
|
|
||||||
|
// Handle bias if present
|
||||||
|
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
|
||||||
|
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
|
||||||
|
mlx.Eval(fusedBias)
|
||||||
|
fusedLinear.Bias = fusedBias
|
||||||
|
}
|
||||||
|
|
||||||
|
attn.ToQKV = fusedLinear
|
||||||
|
attn.Fused = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward computes attention
|
// Forward computes attention
|
||||||
@@ -146,11 +200,24 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
|||||||
L := shape[1]
|
L := shape[1]
|
||||||
D := shape[2]
|
D := shape[2]
|
||||||
|
|
||||||
// Project Q, K, V
|
|
||||||
xFlat := mlx.Reshape(x, B*L, D)
|
xFlat := mlx.Reshape(x, B*L, D)
|
||||||
q := attn.ToQ.Forward(xFlat)
|
|
||||||
k := attn.ToK.Forward(xFlat)
|
var q, k, v *mlx.Array
|
||||||
v := attn.ToV.Forward(xFlat)
|
if attn.Fused && attn.ToQKV != nil {
|
||||||
|
// Fused QKV path: single matmul then split
|
||||||
|
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
|
||||||
|
|
||||||
|
// Split into Q, K, V along last dimension
|
||||||
|
// Each has shape [B*L, dim]
|
||||||
|
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
|
||||||
|
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
|
||||||
|
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
|
||||||
|
} else {
|
||||||
|
// Separate Q, K, V projections
|
||||||
|
q = attn.ToQ.Forward(xFlat)
|
||||||
|
k = attn.ToK.Forward(xFlat)
|
||||||
|
v = attn.ToV.Forward(xFlat)
|
||||||
|
}
|
||||||
|
|
||||||
// Reshape to [B, L, nheads, head_dim]
|
// Reshape to [B, L, nheads, head_dim]
|
||||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||||
@@ -227,7 +294,7 @@ type TransformerBlock struct {
|
|||||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||||
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
|
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||||
// Computed fields
|
// Computed fields
|
||||||
HasModulation bool
|
HasModulation bool
|
||||||
Dim int32
|
Dim int32
|
||||||
@@ -281,8 +348,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
|
|||||||
|
|
||||||
// FinalLayer outputs the denoised patches
|
// FinalLayer outputs the denoised patches
|
||||||
type FinalLayer struct {
|
type FinalLayer struct {
|
||||||
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||||
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
|
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
|
||||||
OutDim int32 // computed from Output
|
OutDim int32 // computed from Output
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,12 +417,11 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
|
|||||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||||
|
|
||||||
// Load weights from tensor blobs with BF16 conversion
|
|
||||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("weights: %w", err)
|
return fmt.Errorf("weights: %w", err)
|
||||||
}
|
}
|
||||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
if err := weights.Load(0); err != nil {
|
||||||
return fmt.Errorf("load weights: %w", err)
|
return fmt.Errorf("load weights: %w", err)
|
||||||
}
|
}
|
||||||
defer weights.ReleaseAll()
|
defer weights.ReleaseAll()
|
||||||
@@ -377,7 +443,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
|
|||||||
func (m *Transformer) initComputedFields() {
|
func (m *Transformer) initComputedFields() {
|
||||||
cfg := m.TransformerConfig
|
cfg := m.TransformerConfig
|
||||||
m.TEmbed.FreqEmbedSize = 256
|
m.TEmbed.FreqEmbedSize = 256
|
||||||
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
|
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
|
||||||
m.CapEmbed.Norm.Eps = 1e-6
|
m.CapEmbed.Norm.Eps = 1e-6
|
||||||
|
|
||||||
for _, block := range m.NoiseRefiners {
|
for _, block := range m.NoiseRefiners {
|
||||||
@@ -391,6 +457,20 @@ func (m *Transformer) initComputedFields() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
|
||||||
|
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
|
||||||
|
func (m *Transformer) FuseAllQKV() {
|
||||||
|
for _, block := range m.NoiseRefiners {
|
||||||
|
block.Attention.FuseQKV()
|
||||||
|
}
|
||||||
|
for _, block := range m.ContextRefiners {
|
||||||
|
block.Attention.FuseQKV()
|
||||||
|
}
|
||||||
|
for _, block := range m.Layers {
|
||||||
|
block.Attention.FuseQKV()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// initTransformerBlock sets computed fields on a transformer block
|
// initTransformerBlock sets computed fields on a transformer block
|
||||||
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
||||||
block.Dim = cfg.Dim
|
block.Dim = cfg.Dim
|
||||||
@@ -404,7 +484,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
|||||||
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
|
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
|
||||||
|
|
||||||
// Init feedforward OutDim
|
// Init feedforward OutDim
|
||||||
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
|
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
|
||||||
|
|
||||||
// Set eps on all RMSNorm layers
|
// Set eps on all RMSNorm layers
|
||||||
block.AttentionNorm1.Eps = cfg.NormEps
|
block.AttentionNorm1.Eps = cfg.NormEps
|
||||||
@@ -423,6 +503,8 @@ type RoPECache struct {
|
|||||||
UnifiedSin *mlx.Array
|
UnifiedSin *mlx.Array
|
||||||
ImgLen int32
|
ImgLen int32
|
||||||
CapLen int32
|
CapLen int32
|
||||||
|
GridH int32 // Image token grid height
|
||||||
|
GridW int32 // Image token grid width
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
|
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
|
||||||
@@ -456,6 +538,8 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
|
|||||||
UnifiedSin: unifiedSin,
|
UnifiedSin: unifiedSin,
|
||||||
ImgLen: imgLen,
|
ImgLen: imgLen,
|
||||||
CapLen: capLen,
|
CapLen: capLen,
|
||||||
|
GridH: hTok,
|
||||||
|
GridW: wTok,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -104,6 +104,8 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
|
|||||||
groupSize := C / gn.NumGroups
|
groupSize := C / gn.NumGroups
|
||||||
|
|
||||||
// Keep the input - we need it for slicing tiles later
|
// Keep the input - we need it for slicing tiles later
|
||||||
|
// Track if we were the ones who kept it, so we can restore state after
|
||||||
|
wasKept := x.Kept()
|
||||||
mlx.Keep(x)
|
mlx.Keep(x)
|
||||||
|
|
||||||
// Compute per-group mean and variance using flattened spatial dimensions
|
// Compute per-group mean and variance using flattened spatial dimensions
|
||||||
@@ -205,6 +207,10 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clean up kept arrays
|
// Clean up kept arrays
|
||||||
|
// Restore x's kept state - only free if we were the ones who kept it
|
||||||
|
if !wasKept {
|
||||||
|
x.Free()
|
||||||
|
}
|
||||||
mean.Free()
|
mean.Free()
|
||||||
invStd.Free()
|
invStd.Free()
|
||||||
if weightGN != nil {
|
if weightGN != nil {
|
||||||
@@ -734,18 +740,26 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
|||||||
h := vae.ConvIn.Forward(z)
|
h := vae.ConvIn.Forward(z)
|
||||||
mlx.Eval(h)
|
mlx.Eval(h)
|
||||||
|
|
||||||
|
prev := h
|
||||||
h = vae.MidBlock.Forward(h)
|
h = vae.MidBlock.Forward(h)
|
||||||
|
prev.Free()
|
||||||
|
|
||||||
for _, upBlock := range vae.UpBlocks {
|
for _, upBlock := range vae.UpBlocks {
|
||||||
|
prev = h
|
||||||
h = upBlock.Forward(h)
|
h = upBlock.Forward(h)
|
||||||
|
prev.Free()
|
||||||
}
|
}
|
||||||
|
|
||||||
prev := h
|
prev = h
|
||||||
h = vae.ConvNormOut.Forward(h)
|
h = vae.ConvNormOut.Forward(h)
|
||||||
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
|
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
|
||||||
|
prev.Free()
|
||||||
|
|
||||||
|
prev = h
|
||||||
h = mlx.SiLU(h)
|
h = mlx.SiLU(h)
|
||||||
h = vae.ConvOut.Forward(h)
|
h = vae.ConvOut.Forward(h)
|
||||||
mlx.Eval(h)
|
mlx.Eval(h)
|
||||||
|
prev.Free()
|
||||||
|
|
||||||
// VAE outputs [-1, 1], convert to [0, 1]
|
// VAE outputs [-1, 1], convert to [0, 1]
|
||||||
h = mlx.MulScalar(h, 0.5)
|
h = mlx.MulScalar(h, 0.5)
|
||||||
@@ -754,7 +768,6 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
|||||||
|
|
||||||
// Convert NHWC -> NCHW for output
|
// Convert NHWC -> NCHW for output
|
||||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||||
prev.Free()
|
|
||||||
mlx.Eval(h)
|
mlx.Eval(h)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
|
|||||||
@@ -26,10 +26,12 @@ type GenerateConfig struct {
|
|||||||
Progress ProgressFunc // Optional progress callback
|
Progress ProgressFunc // Optional progress callback
|
||||||
CapturePath string // GPU capture path (debug)
|
CapturePath string // GPU capture path (debug)
|
||||||
|
|
||||||
// Layer caching options (speedup via shallow layer reuse)
|
// TeaCache options (timestep embedding aware caching)
|
||||||
LayerCache bool // Enable layer caching (default: false)
|
TeaCache bool // TeaCache is always enabled for faster inference
|
||||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
|
||||||
CacheLayers int // Number of shallow layers to cache (default: 15)
|
|
||||||
|
// Fused QKV (fuse Q/K/V projections into single matmul)
|
||||||
|
FusedQKV bool // Enable fused QKV projection (default: false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProgressFunc is called during generation with step progress.
|
// ProgressFunc is called during generation with step progress.
|
||||||
@@ -42,6 +44,7 @@ type Model struct {
|
|||||||
TextEncoder *Qwen3TextEncoder
|
TextEncoder *Qwen3TextEncoder
|
||||||
Transformer *Transformer
|
Transformer *Transformer
|
||||||
VAEDecoder *VAEDecoder
|
VAEDecoder *VAEDecoder
|
||||||
|
qkvFused bool // Track if QKV has been fused (do only once)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load loads the Z-Image model from ollama blob storage.
|
// Load loads the Z-Image model from ollama blob storage.
|
||||||
@@ -196,13 +199,17 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
if cfg.CFGScale <= 0 {
|
if cfg.CFGScale <= 0 {
|
||||||
cfg.CFGScale = 4.0
|
cfg.CFGScale = 4.0
|
||||||
}
|
}
|
||||||
if cfg.LayerCache {
|
// TeaCache enabled by default
|
||||||
if cfg.CacheInterval <= 0 {
|
cfg.TeaCache = true
|
||||||
cfg.CacheInterval = 3
|
if cfg.TeaCacheThreshold <= 0 {
|
||||||
}
|
cfg.TeaCacheThreshold = 0.15
|
||||||
if cfg.CacheLayers <= 0 {
|
}
|
||||||
cfg.CacheLayers = 15 // Half of 30 layers
|
|
||||||
}
|
// Enable fused QKV if requested (only fuse once)
|
||||||
|
if cfg.FusedQKV && !m.qkvFused {
|
||||||
|
m.Transformer.FuseAllQKV()
|
||||||
|
m.qkvFused = true
|
||||||
|
fmt.Println(" Fused QKV enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
useCFG := cfg.NegativePrompt != ""
|
useCFG := cfg.NegativePrompt != ""
|
||||||
@@ -260,12 +267,54 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
mlx.Eval(ropeCache.UnifiedCos)
|
mlx.Eval(ropeCache.UnifiedCos)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
|
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
|
||||||
var stepCache *cache.StepCache
|
var batchedEmb *mlx.Array
|
||||||
if cfg.LayerCache {
|
if useCFG {
|
||||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
|
||||||
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
|
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
|
||||||
cfg.CacheLayers, cfg.CacheInterval)
|
mlx.Keep(batchedEmb)
|
||||||
|
mlx.Eval(batchedEmb)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TeaCache for timestep-aware caching
|
||||||
|
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
|
||||||
|
var teaCache *cache.TeaCache
|
||||||
|
if cfg.TeaCache {
|
||||||
|
skipEarly := 0
|
||||||
|
if useCFG {
|
||||||
|
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
|
||||||
|
}
|
||||||
|
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
|
||||||
|
Threshold: cfg.TeaCacheThreshold,
|
||||||
|
RescaleFactor: 1.0,
|
||||||
|
SkipEarlySteps: skipEarly,
|
||||||
|
})
|
||||||
|
if useCFG {
|
||||||
|
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup frees all kept arrays when we need to abort early
|
||||||
|
cleanup := func() {
|
||||||
|
posEmb.Free()
|
||||||
|
if negEmb != nil {
|
||||||
|
negEmb.Free()
|
||||||
|
}
|
||||||
|
ropeCache.ImgCos.Free()
|
||||||
|
ropeCache.ImgSin.Free()
|
||||||
|
ropeCache.CapCos.Free()
|
||||||
|
ropeCache.CapSin.Free()
|
||||||
|
ropeCache.UnifiedCos.Free()
|
||||||
|
ropeCache.UnifiedSin.Free()
|
||||||
|
if batchedEmb != nil {
|
||||||
|
batchedEmb.Free()
|
||||||
|
}
|
||||||
|
if teaCache != nil {
|
||||||
|
teaCache.Free()
|
||||||
|
}
|
||||||
|
latents.Free()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Denoising loop
|
// Denoising loop
|
||||||
@@ -277,6 +326,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
cleanup()
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -289,50 +339,77 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
}
|
}
|
||||||
|
|
||||||
tCurr := scheduler.Timesteps[i]
|
tCurr := scheduler.Timesteps[i]
|
||||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
|
var noisePred *mlx.Array
|
||||||
|
|
||||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
// TeaCache: check if we should compute or reuse cached output
|
||||||
|
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
|
||||||
|
|
||||||
var output *mlx.Array
|
if shouldCompute {
|
||||||
if stepCache != nil {
|
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
|
||||||
// Use layer caching for faster inference
|
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||||
|
|
||||||
|
var output *mlx.Array
|
||||||
if useCFG {
|
if useCFG {
|
||||||
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
// CFG Batching: single forward pass with batch=2
|
||||||
stepCache, i, cfg.CacheInterval)
|
// Tile patches: [1, L, D] -> [2, L, D]
|
||||||
// Note: CFG with layer cache shares the cache between pos/neg
|
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
|
||||||
// This is approximate but fast - neg prompt uses same cached shallow layers
|
// Tile timestep: [1] -> [2]
|
||||||
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
|
batchedTimestep := mlx.Tile(timestep, []int32{2})
|
||||||
stepCache, i, cfg.CacheInterval)
|
|
||||||
diff := mlx.Sub(posOutput, negOutput)
|
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
|
||||||
|
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
|
||||||
|
|
||||||
|
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
|
||||||
|
outputShape := batchedOutput.Shape()
|
||||||
|
L := outputShape[1]
|
||||||
|
D := outputShape[2]
|
||||||
|
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
|
||||||
|
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
|
||||||
|
|
||||||
|
// Convert to noise predictions (unpatchify and negate)
|
||||||
|
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||||
|
posPred = mlx.Neg(posPred)
|
||||||
|
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||||
|
negPred = mlx.Neg(negPred)
|
||||||
|
|
||||||
|
// Cache pos/neg separately for TeaCache
|
||||||
|
if teaCache != nil {
|
||||||
|
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
|
||||||
|
mlx.Keep(teaCache.Arrays()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply CFG: noisePred = neg + scale * (pos - neg)
|
||||||
|
diff := mlx.Sub(posPred, negPred)
|
||||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||||
output = mlx.Add(negOutput, scaledDiff)
|
noisePred = mlx.Add(negPred, scaledDiff)
|
||||||
} else {
|
|
||||||
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
|
||||||
stepCache, i, cfg.CacheInterval)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Standard forward without caching
|
|
||||||
if useCFG {
|
|
||||||
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
|
||||||
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
|
|
||||||
diff := mlx.Sub(posOutput, negOutput)
|
|
||||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
|
||||||
output = mlx.Add(negOutput, scaledDiff)
|
|
||||||
} else {
|
} else {
|
||||||
|
// Non-CFG forward pass
|
||||||
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||||
}
|
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||||
}
|
noisePred = mlx.Neg(noisePred)
|
||||||
|
|
||||||
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
// Update TeaCache
|
||||||
noisePred = mlx.Neg(noisePred)
|
if teaCache != nil {
|
||||||
|
teaCache.UpdateCache(noisePred, tCurr)
|
||||||
|
mlx.Keep(teaCache.Arrays()...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
|
||||||
|
// CFG mode: get cached pos/neg and compute CFG fresh
|
||||||
|
posPred, negPred := teaCache.GetCFGCached()
|
||||||
|
diff := mlx.Sub(posPred, negPred)
|
||||||
|
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||||
|
noisePred = mlx.Add(negPred, scaledDiff)
|
||||||
|
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
|
||||||
|
} else {
|
||||||
|
// Non-CFG mode: reuse cached noise prediction
|
||||||
|
noisePred = teaCache.GetCached()
|
||||||
|
fmt.Printf(" [TeaCache: reusing cached output]\n")
|
||||||
|
}
|
||||||
|
|
||||||
oldLatents := latents
|
oldLatents := latents
|
||||||
latents = scheduler.Step(noisePred, latents, i)
|
latents = scheduler.Step(noisePred, latents, i)
|
||||||
|
|
||||||
// Keep latents and any cached arrays
|
|
||||||
if stepCache != nil {
|
|
||||||
mlx.Keep(stepCache.Arrays()...)
|
|
||||||
}
|
|
||||||
mlx.Eval(latents)
|
mlx.Eval(latents)
|
||||||
oldLatents.Free()
|
oldLatents.Free()
|
||||||
|
|
||||||
@@ -361,8 +438,14 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
|||||||
ropeCache.CapSin.Free()
|
ropeCache.CapSin.Free()
|
||||||
ropeCache.UnifiedCos.Free()
|
ropeCache.UnifiedCos.Free()
|
||||||
ropeCache.UnifiedSin.Free()
|
ropeCache.UnifiedSin.Free()
|
||||||
if stepCache != nil {
|
if batchedEmb != nil {
|
||||||
stepCache.Free()
|
batchedEmb.Free()
|
||||||
|
}
|
||||||
|
if teaCache != nil {
|
||||||
|
hits, misses := teaCache.Stats()
|
||||||
|
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
|
||||||
|
hits, misses, float64(hits)/float64(hits+misses)*100)
|
||||||
|
teaCache.Free()
|
||||||
}
|
}
|
||||||
|
|
||||||
// VAE decode
|
// VAE decode
|
||||||
|
|||||||
@@ -10,6 +10,13 @@ type Layer interface {
|
|||||||
Forward(x *mlx.Array) *mlx.Array
|
Forward(x *mlx.Array) *mlx.Array
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LinearLayer is an interface for linear layers (both regular and quantized).
|
||||||
|
// This allows swapping between Linear and QuantizedLinear at runtime.
|
||||||
|
type LinearLayer interface {
|
||||||
|
Forward(x *mlx.Array) *mlx.Array
|
||||||
|
OutputDim() int32 // Returns the output dimension of the layer
|
||||||
|
}
|
||||||
|
|
||||||
// Linear applies an affine transformation: y = x @ W.T + b
|
// Linear applies an affine transformation: y = x @ W.T + b
|
||||||
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
|
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
|
||||||
type Linear struct {
|
type Linear struct {
|
||||||
@@ -49,6 +56,11 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
|
|||||||
return mlx.Linear(x, w)
|
return mlx.Linear(x, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OutputDim returns the output dimension of the linear layer.
|
||||||
|
func (l *Linear) OutputDim() int32 {
|
||||||
|
return l.Weight.Shape()[0]
|
||||||
|
}
|
||||||
|
|
||||||
// ToQuantized converts this Linear to a QuantizedLinear.
|
// ToQuantized converts this Linear to a QuantizedLinear.
|
||||||
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
|
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
|
||||||
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
|
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
|
||||||
@@ -84,6 +96,13 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OutputDim returns the output dimension of the quantized linear layer.
|
||||||
|
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
|
||||||
|
// The output dimension is the first dimension of the weight.
|
||||||
|
func (ql *QuantizedLinear) OutputDim() int32 {
|
||||||
|
return ql.Weight.Shape()[0]
|
||||||
|
}
|
||||||
|
|
||||||
// RMSNorm represents an RMS normalization layer.
|
// RMSNorm represents an RMS normalization layer.
|
||||||
type RMSNorm struct {
|
type RMSNorm struct {
|
||||||
Weight *mlx.Array `weight:"weight"`
|
Weight *mlx.Array `weight:"weight"`
|
||||||
|
|||||||
22
x/imagegen/quantize.go
Normal file
22
x/imagegen/quantize.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package imagegen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||||
|
// When quantize is true, returns multiple layers (weight + scales + biases).
|
||||||
|
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
|
||||||
|
|
||||||
|
// ShouldQuantize reports whether the tensor called name, belonging to the
// given model component, should be quantized.
//
// Only linear weights qualify:
//   - nothing in the "vae" component is quantized;
//   - names containing "embed" or "norm" are skipped;
//   - of the rest, only names ending in ".weight" are quantized, which leaves
//     biases untouched.
func ShouldQuantize(name, component string) bool {
	switch {
	case component == "vae":
		return false
	case strings.Contains(name, "embed"), strings.Contains(name, "norm"):
		return false
	default:
		return strings.HasSuffix(name, ".weight")
	}
}
|
||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@@ -34,7 +33,8 @@ type Request struct {
|
|||||||
|
|
||||||
// Response is streamed back for each progress update
|
// Response is streamed back for each progress update
|
||||||
type Response struct {
|
type Response struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"` // Base64-encoded PNG
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,10 +191,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save image
|
// Encode image as base64 PNG
|
||||||
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
|
imageData, err := imagegen.EncodeImageBase64(img)
|
||||||
if err := imagegen.SaveImage(img, outPath); err != nil {
|
if err != nil {
|
||||||
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
|
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
|
||||||
data, _ := json.Marshal(resp)
|
data, _ := json.Marshal(resp)
|
||||||
w.Write(data)
|
w.Write(data)
|
||||||
w.Write([]byte("\n"))
|
w.Write([]byte("\n"))
|
||||||
@@ -204,11 +204,12 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Free the generated image array and clean up MLX state
|
// Free the generated image array and clean up MLX state
|
||||||
img.Free()
|
img.Free()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
|
mlx.MetalResetPeakMemory()
|
||||||
|
|
||||||
// Send final response
|
// Send final response with image data
|
||||||
resp := Response{
|
resp := Response{
|
||||||
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
|
Image: imageData,
|
||||||
Done: true,
|
Done: true,
|
||||||
}
|
}
|
||||||
data, _ := json.Marshal(resp)
|
data, _ := json.Marshal(resp)
|
||||||
w.Write(data)
|
w.Write(data)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||||
|
"github.com/ollama/ollama/x/imagegen/nn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WeightSource is an interface for loading weights.
|
// WeightSource is an interface for loading weights.
|
||||||
@@ -102,6 +103,22 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle nn.LinearLayer interface fields specially
|
||||||
|
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
|
||||||
|
if !hasTag {
|
||||||
|
continue // no tag = skip
|
||||||
|
}
|
||||||
|
layer, err := LoadLinearLayer(weights, fullPath)
|
||||||
|
if err != nil {
|
||||||
|
if !optional {
|
||||||
|
*errs = append(*errs, fullPath+": "+err.Error())
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fieldVal.Set(reflect.ValueOf(layer))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Handle by kind
|
// Handle by kind
|
||||||
switch fieldVal.Kind() {
|
switch fieldVal.Kind() {
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
@@ -176,3 +193,64 @@ func joinPath(prefix, suffix string) string {
|
|||||||
}
|
}
|
||||||
return prefix + "." + suffix
|
return prefix + "." + suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
|
||||||
|
// If {path}.weight_scale exists, dequantizes the weights.
|
||||||
|
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
|
||||||
|
// Check if this is a quantized layer by looking for scale tensor
|
||||||
|
scalePath := path + ".weight_scale"
|
||||||
|
if weights.HasTensor(scalePath) {
|
||||||
|
weight, err := weights.GetTensor(path + ".weight")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
scales, err := weights.GetTensor(scalePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bias is optional
|
||||||
|
var bias *mlx.Array
|
||||||
|
biasPath := path + ".bias"
|
||||||
|
if weights.HasTensor(biasPath) {
|
||||||
|
bias, _ = weights.GetTensor(biasPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
var qbiases *mlx.Array
|
||||||
|
qbiasPath := path + ".weight_qbias"
|
||||||
|
if weights.HasTensor(qbiasPath) {
|
||||||
|
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mlx.MetalIsAvailable() {
|
||||||
|
return &nn.QuantizedLinear{
|
||||||
|
Weight: weight,
|
||||||
|
Scales: scales,
|
||||||
|
QBiases: qbiases,
|
||||||
|
Bias: bias,
|
||||||
|
GroupSize: 32,
|
||||||
|
Bits: 8,
|
||||||
|
Mode: "affine",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
|
||||||
|
return nn.NewLinear(dequantized, bias), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load as regular Linear
|
||||||
|
weight, err := weights.GetTensor(path + ".weight")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bias is optional
|
||||||
|
var bias *mlx.Array
|
||||||
|
biasPath := path + ".bias"
|
||||||
|
if weights.HasTensor(biasPath) {
|
||||||
|
bias, _ = weights.GetTensor(biasPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nn.NewLinear(weight, bias), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -46,7 +48,8 @@ type completionRequest struct {
|
|||||||
|
|
||||||
// completionResponse is received from the subprocess
|
// completionResponse is received from the subprocess
|
||||||
type completionResponse struct {
|
type completionResponse struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,7 +72,7 @@ func NewServer(modelName string) (*Server, error) {
|
|||||||
port = rand.Intn(65535-49152) + 49152
|
port = rand.Intn(65535-49152) + 49152
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the ollama executable path
|
// Get the ollama-mlx executable path (in same directory as current executable)
|
||||||
exe, err := os.Executable()
|
exe, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||||
@@ -77,11 +80,42 @@ func NewServer(modelName string) (*Server, error) {
|
|||||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||||
exe = eval
|
exe = eval
|
||||||
}
|
}
|
||||||
|
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
|
||||||
|
|
||||||
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
|
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
|
||||||
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
|
|
||||||
|
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
|
||||||
|
libraryPaths := []string{ml.LibOllamaPath}
|
||||||
|
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
||||||
|
libraryPaths = append(libraryPaths, mlxDirs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append existing LD_LIBRARY_PATH if set
|
||||||
|
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||||
|
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||||
|
|
||||||
|
// Update or add LD_LIBRARY_PATH in cmd.Env
|
||||||
|
found := false
|
||||||
|
for i := range cmd.Env {
|
||||||
|
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
||||||
|
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
||||||
|
}
|
||||||
|
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||||
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -112,7 +146,7 @@ func NewServer(modelName string) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("starting image runner subprocess", "model", modelName, "port", port)
|
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to start image runner: %w", err)
|
return nil, fmt.Errorf("failed to start image runner: %w", err)
|
||||||
}
|
}
|
||||||
@@ -250,15 +284,23 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
|
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream responses
|
// Stream responses - use large buffer for base64 image data
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var cresp completionResponse
|
var cresp completionResponse
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
|
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
content := cresp.Content
|
||||||
|
// If this is the final response with an image, encode it in the content
|
||||||
|
if cresp.Done && cresp.Image != "" {
|
||||||
|
content = "IMAGE_BASE64:" + cresp.Image
|
||||||
|
}
|
||||||
|
|
||||||
fn(llm.CompletionResponse{
|
fn(llm.CompletionResponse{
|
||||||
Content: cresp.Content,
|
Content: content,
|
||||||
Done: cresp.Done,
|
Done: cresp.Done,
|
||||||
})
|
})
|
||||||
if cresp.Done {
|
if cresp.Done {
|
||||||
|
|||||||
@@ -45,24 +45,33 @@ func download(ctx context.Context, opts DownloadOptions) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter existing
|
// Calculate total from all blobs (for accurate progress reporting on resume)
|
||||||
var blobs []Blob
|
|
||||||
var total int64
|
var total int64
|
||||||
|
for _, b := range opts.Blobs {
|
||||||
|
total += b.Size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out already-downloaded blobs and track completed bytes
|
||||||
|
var blobs []Blob
|
||||||
|
var alreadyCompleted int64
|
||||||
for _, b := range opts.Blobs {
|
for _, b := range opts.Blobs {
|
||||||
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
|
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
|
||||||
if opts.Logger != nil {
|
if opts.Logger != nil {
|
||||||
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
|
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
|
||||||
}
|
}
|
||||||
|
alreadyCompleted += b.Size
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
blobs = append(blobs, b)
|
blobs = append(blobs, b)
|
||||||
total += b.Size
|
|
||||||
}
|
}
|
||||||
if len(blobs) == 0 {
|
if len(blobs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
token := opts.Token
|
token := opts.Token
|
||||||
|
progress := newProgressTracker(total, opts.Progress)
|
||||||
|
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront
|
||||||
|
|
||||||
d := &downloader{
|
d := &downloader{
|
||||||
client: cmp.Or(opts.Client, defaultClient),
|
client: cmp.Or(opts.Client, defaultClient),
|
||||||
baseURL: opts.BaseURL,
|
baseURL: opts.BaseURL,
|
||||||
@@ -72,7 +81,7 @@ func download(ctx context.Context, opts DownloadOptions) error {
|
|||||||
getToken: opts.GetToken,
|
getToken: opts.GetToken,
|
||||||
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
||||||
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
|
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
|
||||||
progress: newProgressTracker(total, opts.Progress),
|
progress: progress,
|
||||||
speeds: &speedTracker{},
|
speeds: &speedTracker{},
|
||||||
logger: opts.Logger,
|
logger: opts.Logger,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -110,8 +110,6 @@ var defaultClient = &http.Client{
|
|||||||
MaxIdleConnsPerHost: 100,
|
MaxIdleConnsPerHost: 100,
|
||||||
IdleConnTimeout: 90 * time.Second,
|
IdleConnTimeout: 90 * time.Second,
|
||||||
},
|
},
|
||||||
Timeout: 5 * time.Minute,
|
|
||||||
// Don't follow redirects automatically - we handle them manually
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
return http.ErrUseLastResponse
|
return http.ErrUseLastResponse
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -284,6 +284,83 @@ func TestDownloadSkipsExisting(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDownloadResumeProgressTotal(t *testing.T) {
|
||||||
|
// Test that when resuming a download with some blobs already present:
|
||||||
|
// 1. Total reflects ALL blob sizes (not just remaining)
|
||||||
|
// 2. Completed starts at the size of already-downloaded blobs
|
||||||
|
serverDir := t.TempDir()
|
||||||
|
blob1, data1 := createTestBlob(t, serverDir, 1000)
|
||||||
|
blob2, data2 := createTestBlob(t, serverDir, 2000)
|
||||||
|
blob3, data3 := createTestBlob(t, serverDir, 3000)
|
||||||
|
|
||||||
|
// Pre-populate client with blob1 and blob2 (simulating partial download)
|
||||||
|
clientDir := t.TempDir()
|
||||||
|
for _, b := range []struct {
|
||||||
|
blob Blob
|
||||||
|
data []byte
|
||||||
|
}{{blob1, data1}, {blob2, data2}} {
|
||||||
|
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, b.data, 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
digest := filepath.Base(r.URL.Path)
|
||||||
|
path := filepath.Join(serverDir, digestToPath(digest))
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(data)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var firstCompleted, firstTotal int64
|
||||||
|
var gotFirstProgress bool
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
err := Download(context.Background(), DownloadOptions{
|
||||||
|
Blobs: []Blob{blob1, blob2, blob3},
|
||||||
|
BaseURL: server.URL,
|
||||||
|
DestDir: clientDir,
|
||||||
|
Concurrency: 1,
|
||||||
|
Progress: func(completed, total int64) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if !gotFirstProgress {
|
||||||
|
firstCompleted = completed
|
||||||
|
firstTotal = total
|
||||||
|
gotFirstProgress = true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Download failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Total should be sum of ALL blobs, not just blob3
|
||||||
|
expectedTotal := blob1.Size + blob2.Size + blob3.Size
|
||||||
|
if firstTotal != expectedTotal {
|
||||||
|
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First progress call should show already-completed bytes from blob1+blob2
|
||||||
|
expectedCompleted := blob1.Size + blob2.Size
|
||||||
|
if firstCompleted < expectedCompleted {
|
||||||
|
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify blob3 was downloaded
|
||||||
|
verifyBlob(t, clientDir, blob3, data3)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDownloadDigestMismatch(t *testing.T) {
|
func TestDownloadDigestMismatch(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Return wrong data
|
// Return wrong data
|
||||||
|
|||||||
@@ -54,6 +54,16 @@ func (r *Registry) RegisterBash() {
|
|||||||
r.Register(&BashTool{})
|
r.Register(&BashTool{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterWebSearch adds the web search tool to the registry.
|
||||||
|
func (r *Registry) RegisterWebSearch() {
|
||||||
|
r.Register(&WebSearchTool{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterWebFetch adds the web fetch tool to the registry.
|
||||||
|
func (r *Registry) RegisterWebFetch() {
|
||||||
|
r.Register(&WebFetchTool{})
|
||||||
|
}
|
||||||
|
|
||||||
// Get retrieves a tool by name.
|
// Get retrieves a tool by name.
|
||||||
func (r *Registry) Get(name string) (Tool, bool) {
|
func (r *Registry) Get(name string) (Tool, bool) {
|
||||||
tool, ok := r.tools[name]
|
tool, ok := r.tools[name]
|
||||||
|
|||||||
162
x/tools/webfetch.go
Normal file
162
x/tools/webfetch.go
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
webFetchAPI = "https://ollama.com/api/web_fetch"
|
||||||
|
webFetchTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrWebFetchAuthRequired is returned when web fetch requires authentication
|
||||||
|
var ErrWebFetchAuthRequired = errors.New("web fetch requires authentication")
|
||||||
|
|
||||||
|
// WebFetchTool is the tool that fetches web pages through Ollama's hosted
// web fetch API. It carries no state.
type WebFetchTool struct{}
|
||||||
|
|
||||||
|
// Name returns the tool name.
|
||||||
|
func (w *WebFetchTool) Name() string {
|
||||||
|
return "web_fetch"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description returns a description of the tool.
|
||||||
|
func (w *WebFetchTool) Description() string {
|
||||||
|
return "Fetch and extract text content from a web page. Use this to read the full content of a URL found in search results or provided by the user."
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema returns the tool's parameter schema.
|
||||||
|
func (w *WebFetchTool) Schema() api.ToolFunction {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("url", api.ToolProperty{
|
||||||
|
Type: api.PropertyType{"string"},
|
||||||
|
Description: "The URL to fetch and extract content from",
|
||||||
|
})
|
||||||
|
return api.ToolFunction{
|
||||||
|
Name: w.Name(),
|
||||||
|
Description: w.Description(),
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: props,
|
||||||
|
Required: []string{"url"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// webFetchRequest is the request body for the web fetch API.
|
||||||
|
type webFetchRequest struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// webFetchResponse is the response from the web fetch API.
|
||||||
|
type webFetchResponse struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Links []string `json:"links,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute fetches content from a web page.
|
||||||
|
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
|
||||||
|
func (w *WebFetchTool) Execute(args map[string]any) (string, error) {
|
||||||
|
urlStr, ok := args["url"].(string)
|
||||||
|
if !ok || urlStr == "" {
|
||||||
|
return "", fmt.Errorf("url parameter is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate URL
|
||||||
|
if _, err := url.Parse(urlStr); err != nil {
|
||||||
|
return "", fmt.Errorf("invalid URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare request
|
||||||
|
reqBody := webFetchRequest{
|
||||||
|
URL: urlStr,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBody, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshaling request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse URL and add timestamp for signing
|
||||||
|
fetchURL, err := url.Parse(webFetchAPI)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parsing fetch URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q := fetchURL.Query()
|
||||||
|
q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
fetchURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Sign the request using Ollama key (~/.ollama/id_ed25519)
|
||||||
|
ctx := context.Background()
|
||||||
|
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, fetchURL.RequestURI())
|
||||||
|
signature, err := auth.Sign(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("signing request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fetchURL.String(), bytes.NewBuffer(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if signature != "" {
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
client := &http.Client{Timeout: webFetchTimeout}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("sending request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("reading response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
return "", ErrWebFetchAuthRequired
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("web fetch API returned status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
var fetchResp webFetchResponse
|
||||||
|
if err := json.Unmarshal(body, &fetchResp); err != nil {
|
||||||
|
return "", fmt.Errorf("parsing response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format result
|
||||||
|
var sb strings.Builder
|
||||||
|
if fetchResp.Title != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("Title: %s\n\n", fetchResp.Title))
|
||||||
|
}
|
||||||
|
|
||||||
|
if fetchResp.Content != "" {
|
||||||
|
sb.WriteString("Content:\n")
|
||||||
|
sb.WriteString(fetchResp.Content)
|
||||||
|
} else {
|
||||||
|
sb.WriteString("No content could be extracted from the page.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user