mirror of
https://github.com/ollama/ollama.git
synced 2026-01-16 19:41:24 -05:00
Compare commits
1 Commits
main
...
imagegen-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca1faae6a2 |
18
Dockerfile
18
Dockerfile
@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||
# install epel-release for ccache
|
||||
RUN yum install -y yum-utils epel-release \
|
||||
&& dnf install -y clang ccache git \
|
||||
&& dnf install -y clang ccache \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
|
||||
ENV CC=clang CXX=clang++
|
||||
|
||||
@@ -149,7 +149,6 @@ COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
@@ -157,6 +156,14 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN mkdir -p dist/bin
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -165,14 +172,12 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -180,6 +185,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
v0.4.1
|
||||
@@ -270,10 +270,10 @@ cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
|
||||
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
|
||||
|
||||
```shell
|
||||
go build -tags mlx .
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
Finally, start the server:
|
||||
|
||||
21
api/types.go
21
api/types.go
@@ -97,6 +97,15 @@ type GenerateRequest struct {
|
||||
// request, for multimodal models.
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
|
||||
// Width is the width of the generated image (for image generation models).
|
||||
Width int32 `json:"width,omitempty"`
|
||||
|
||||
// Height is the height of the generated image (for image generation models).
|
||||
Height int32 `json:"height,omitempty"`
|
||||
|
||||
// Steps is the number of diffusion steps (for image generation models).
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]any `json:"options"`
|
||||
@@ -860,6 +869,18 @@ type GenerateResponse struct {
|
||||
// Logprobs contains log probability information for the generated tokens,
|
||||
// if requested via the Logprobs parameter.
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Status describes the current phase of generation (e.g., "generating image").
|
||||
Status string `json:"status,omitempty"`
|
||||
|
||||
// Total is the total count for the current phase (e.g., total steps).
|
||||
Total int64 `json:"total,omitempty"`
|
||||
|
||||
// Completed is the completed count for the current phase.
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
|
||||
// Images contains base64-encoded generated images for image generation models.
|
||||
Images []string `json:"images,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
|
||||
13
docs/api.md
13
docs/api.md
@@ -47,6 +47,12 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
Image generation parameters (for image generation models):
|
||||
|
||||
- `width`: (optional) width of the generated image in pixels (default: model-specific)
|
||||
- `height`: (optional) height of the generated image in pixels (default: model-specific)
|
||||
- `steps`: (optional) number of diffusion steps (default: model-specific)
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema
|
||||
@@ -106,6 +112,13 @@ The final response in the stream also includes additional data about the generat
|
||||
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
|
||||
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
|
||||
|
||||
For image generation models, the response includes additional fields:
|
||||
|
||||
- `status`: describes the current phase (e.g., "generating image")
|
||||
- `total`: total count for the current phase (e.g., total steps)
|
||||
- `completed`: completed count for the current phase
|
||||
- `images`: array of base64-encoded generated images (in final response)
|
||||
|
||||
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` \* `10^9`.
|
||||
|
||||
```json
|
||||
|
||||
@@ -1468,6 +1468,7 @@ type CompletionRequest struct {
|
||||
// Image generation fields
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1518,10 +1519,14 @@ type CompletionResponse struct {
|
||||
// Logprobs contains log probability information if requested
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Image generation fields
|
||||
Image []byte `json:"image,omitempty"` // Generated image
|
||||
Step int `json:"step,omitempty"` // Current generation step
|
||||
Total int `json:"total,omitempty"` // Total generation steps
|
||||
// Image contains base64-encoded image data for image generation
|
||||
Image string `json:"image,omitempty"`
|
||||
|
||||
// Step is the current step in image generation
|
||||
Step int `json:"step,omitempty"`
|
||||
|
||||
// TotalSteps is the total number of steps for image generation
|
||||
TotalSteps int `json:"total_steps,omitempty"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
|
||||
@@ -50,6 +50,11 @@ type EmbedWriter struct {
|
||||
encodingFormat string
|
||||
}
|
||||
|
||||
type ImageWriter struct {
|
||||
BaseWriter
|
||||
done bool
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
@@ -274,6 +279,36 @@ func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Image generation doesn't support streaming in the OpenAI API sense,
|
||||
// so we only write the response when done with images
|
||||
if generateResponse.Done && len(generateResponse.Images) > 0 {
|
||||
w.done = true
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ImageWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
@@ -393,6 +428,43 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageGenerationRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq := openai.FromImageGenerationRequest(req)
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
|
||||
@@ -961,3 +961,143 @@ func TestRetrieveMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerationsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
streamFalse := false
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image generation handler",
|
||||
body: `{
|
||||
"model": "flux",
|
||||
"prompt": "a cat"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "flux",
|
||||
Prompt: "a cat",
|
||||
Stream: &streamFalse,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation with size",
|
||||
body: `{
|
||||
"model": "flux",
|
||||
"prompt": "a dog",
|
||||
"size": "512x512"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "flux",
|
||||
Prompt: "a dog",
|
||||
Stream: &streamFalse,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing prompt error",
|
||||
body: `{
|
||||
"model": "flux"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing model error",
|
||||
body: `{
|
||||
"prompt": "a cat"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatalf("requests did not match\nExpected: %+v\nActual: %+v", tc.req, *capturedRequest)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatalf("errors did not match\nExpected: %+v\nActual: %+v", tc.err, errResp)
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageWriterIntegration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("transforms generate response to openai format", func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware())
|
||||
router.POST("/api/generate", func(c *gin.Context) {
|
||||
// Simulate an image generation response
|
||||
generateResponse := api.GenerateResponse{
|
||||
Done: true,
|
||||
CreatedAt: time.Now(),
|
||||
Images: []string{"base64encodedimage"},
|
||||
}
|
||||
c.JSON(http.StatusOK, generateResponse)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(`{"model":"flux","prompt":"a cat"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var response openai.ImageGenerationResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(response.Data) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(response.Data))
|
||||
}
|
||||
if response.Data[0].B64JSON != "base64encodedimage" {
|
||||
t.Fatalf("expected image data 'base64encodedimage', got '%s'", response.Data[0].B64JSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -737,3 +737,46 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageURLOrData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageURLOrData contains either a URL or base64-encoded image data.
|
||||
type ImageURLOrData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
|
||||
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
|
||||
stream := false
|
||||
return api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
Stream: &stream,
|
||||
}
|
||||
}
|
||||
|
||||
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
|
||||
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
|
||||
data := make([]ImageURLOrData, 0)
|
||||
for _, img := range resp.Images {
|
||||
data = append(data, ImageURLOrData{B64JSON: img})
|
||||
}
|
||||
return ImageGenerationResponse{
|
||||
Created: resp.CreatedAt.Unix(),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
cmake --preset MLX \
|
||||
@@ -71,12 +71,10 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
|
||||
# Copy MLX libraries to same directory as executable for dlopen
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
|
||||
@@ -84,10 +82,12 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/ollama-mlx
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
@@ -154,6 +154,7 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
@@ -165,12 +166,13 @@ _build_macapp() {
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
@@ -178,7 +180,7 @@ _build_macapp() {
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
|
||||
@@ -523,6 +523,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
Truncate: req.Truncate == nil || *req.Truncate,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
@@ -538,6 +541,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||
}
|
||||
|
||||
// Image generation fields
|
||||
if cr.Image != "" {
|
||||
res.Images = []string{cr.Image}
|
||||
}
|
||||
if cr.TotalSteps > 0 {
|
||||
res.Status = "generating image"
|
||||
res.Completed = int64(cr.Step)
|
||||
res.Total = int64(cr.TotalSteps)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done)
|
||||
if err != nil {
|
||||
@@ -1594,8 +1607,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// Experimental OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", s.handleImageGeneration)
|
||||
// OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
@@ -1917,62 +1930,6 @@ func toolCallId() string {
|
||||
return "call_" + strings.ToLower(string(b))
|
||||
}
|
||||
|
||||
func (s *Server) handleImageGeneration(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Size string `json:"size"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, api.Options{}, nil)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size (e.g., "1024x768") into width and height
|
||||
width, height := int32(1024), int32(1024)
|
||||
if req.Size != "" {
|
||||
if _, err := fmt.Sscanf(req.Size, "%dx%d", &width, &height); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid size format, expected WxH"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var image []byte
|
||||
err = runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
}, func(resp llm.CompletionResponse) {
|
||||
if len(resp.Image) > 0 {
|
||||
image = resp.Image
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"created": time.Now().Unix(),
|
||||
"data": []gin.H{{"b64_json": base64.StdEncoding.EncodeToString(image)}},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) ChatHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
|
||||
@@ -91,9 +91,7 @@ func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, ke
|
||||
}
|
||||
|
||||
// generateImageWithOptions generates an image with the given options.
|
||||
// Note: opts are currently unused as the native API doesn't support size parameters.
|
||||
// Use OpenAI-compatible endpoint (/v1/images/generations) for dimension control.
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, _ ImageGenOptions) error {
|
||||
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -102,7 +100,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
// Note: Size is only available via OpenAI-compatible /v1/images/generations endpoint
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -116,26 +116,19 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
var stepBar *progress.StepBar
|
||||
var imageBase64 string
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 && resp.Completed > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -235,12 +228,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -255,26 +245,19 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
var imageBase64 string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 && resp.Completed > 0 {
|
||||
if stepBar == nil {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
stepBar.Set(int(resp.Completed))
|
||||
}
|
||||
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
// Handle final response with image data
|
||||
if resp.Done && len(resp.Images) > 0 {
|
||||
imageBase64 = resp.Images[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -65,12 +65,12 @@ func (s *utf8Streamer) Flush() string {
|
||||
return result
|
||||
}
|
||||
|
||||
func init() {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
// Lazy initialization of generationStream
|
||||
if generationStream == nil {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
@@ -78,11 +79,6 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MLX initialized successfully
|
||||
if !mlx.IsMLXAvailable() {
|
||||
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
|
||||
}
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx.h"
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// Forward declaration for Go callback
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
|
||||
//
|
||||
//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
|
||||
package mlx
|
||||
@@ -1,439 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
// This tool generates MLX-C dynamic loading wrappers.
|
||||
// Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Function describes one MLX-C function declaration parsed from a header.
type Function struct {
	// Name is the function identifier (an mlx_* or _mlx_* symbol).
	Name string
	// ReturnType is the C return type with runs of whitespace collapsed.
	ReturnType string
	// Params is the parameter list text found between the parentheses.
	Params string
	// ParamNames holds the parameter identifiers in declaration order.
	ParamNames []string
	// NeedsARM64Guard reports whether generated code for this function is
	// wrapped in an ARM64-only preprocessor guard.
	NeedsARM64Guard bool
}
|
||||
|
||||
// findHeaders walks directory recursively and returns the path of every
// non-directory entry whose path ends in ".h", in the order filepath.WalkDir
// visits them. If the walk reports an error, that error is returned together
// with whatever paths were collected before it.
func findHeaders(directory string) ([]string, error) {
	var found []string

	visit := func(path string, entry fs.DirEntry, walkErr error) error {
		switch {
		case walkErr != nil:
			// Abort the walk on the first error.
			return walkErr
		case entry.IsDir():
			return nil
		case strings.HasSuffix(path, ".h"):
			found = append(found, path)
		}
		return nil
	}

	walkErr := filepath.WalkDir(directory, visit)
	return found, walkErr
}
|
||||
|
||||
// Patterns used by cleanContent, compiled once at package initialization.
var (
	// cleanCommentRe matches a block comment (the s flag lets it span lines)
	// or a line comment up to, but not including, the newline. Using a single
	// alternation means whichever comment opener appears first wins, so "//"
	// inside a block comment and "/*" inside a line comment are handled.
	cleanCommentRe = regexp.MustCompile(`(?s)/\*.*?\*/|//[^\n]*`)
	// cleanDirectiveRe matches a preprocessor directive line.
	cleanDirectiveRe = regexp.MustCompile(`(?m)^\s*#.*?$`)
	// cleanExternCRe matches the opening line of an extern "C" block.
	cleanExternCRe = regexp.MustCompile(`extern\s+"C"\s*\{\s*?\n`)
	// cleanLoneBraceRe matches a closing brace that sits alone on its line.
	cleanLoneBraceRe = regexp.MustCompile(`\n\s*\}\s*\n`)
	// cleanSpaceRe matches any run of whitespace.
	cleanSpaceRe = regexp.MustCompile(`\s+`)
)

// cleanContent strips a C header down to a single line of declarations so
// that parseFunctions can match it with one regular expression.
//
// It removes comments (both /* ... */, including ones spanning several lines,
// and // ...), preprocessor directive lines, extern "C" { openers and closing
// braces that stand alone on a line, then collapses every run of whitespace
// into one space.
func cleanContent(content string) string {
	// A comment is equivalent to whitespace in C, so substitute a space
	// rather than nothing; the final collapse removes any excess.
	content = cleanCommentRe.ReplaceAllString(content, " ")

	// Remove preprocessor directives (lines starting with #).
	content = cleanDirectiveRe.ReplaceAllString(content, "")

	// Only remove the extern "C" { line itself, not the content inside.
	content = cleanExternCRe.ReplaceAllString(content, "\n")
	// Remove standalone closing braces that are not part of a declaration.
	content = cleanLoneBraceRe.ReplaceAllString(content, "\n")

	// Collapse whitespace and newlines.
	return cleanSpaceRe.ReplaceAllString(content, " ")
}
|
||||
|
||||
func extractParamNames(params string) []string {
|
||||
if params == "" || strings.TrimSpace(params) == "void" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var names []string
|
||||
|
||||
// Split by comma, but respect parentheses (for function pointers)
|
||||
parts := splitParams(params)
|
||||
|
||||
// Remove array brackets
|
||||
arrayBrackets := regexp.MustCompile(`\[.*?\]`)
|
||||
|
||||
// Function pointer pattern
|
||||
funcPtrPattern := regexp.MustCompile(`\(\s*\*\s*(\w+)\s*\)`)
|
||||
|
||||
// Type keywords to skip
|
||||
typeKeywords := map[string]bool{
|
||||
"const": true,
|
||||
"struct": true,
|
||||
"unsigned": true,
|
||||
"signed": true,
|
||||
"long": true,
|
||||
"short": true,
|
||||
"int": true,
|
||||
"char": true,
|
||||
"float": true,
|
||||
"double": true,
|
||||
"void": true,
|
||||
"size_t": true,
|
||||
"uint8_t": true,
|
||||
"uint16_t": true,
|
||||
"uint32_t": true,
|
||||
"uint64_t": true,
|
||||
"int8_t": true,
|
||||
"int16_t": true,
|
||||
"int32_t": true,
|
||||
"int64_t": true,
|
||||
"intptr_t": true,
|
||||
"uintptr_t": true,
|
||||
}
|
||||
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove array brackets
|
||||
part = arrayBrackets.ReplaceAllString(part, "")
|
||||
|
||||
// For function pointers like "void (*callback)(int)"
|
||||
if matches := funcPtrPattern.FindStringSubmatch(part); len(matches) > 1 {
|
||||
names = append(names, matches[1])
|
||||
continue
|
||||
}
|
||||
|
||||
// Regular parameter: last identifier
|
||||
tokens := regexp.MustCompile(`\w+`).FindAllString(part, -1)
|
||||
if len(tokens) > 0 {
|
||||
// The last token is usually the parameter name
|
||||
// Skip type keywords
|
||||
for i := len(tokens) - 1; i >= 0; i-- {
|
||||
if !typeKeywords[tokens[i]] {
|
||||
names = append(names, tokens[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
// splitParams splits a C parameter list on commas that are not nested inside
// parentheses, so a function pointer parameter such as
// "void (*cb)(int, int)" stays in one piece. Each piece is returned with
// surrounding whitespace trimmed; an empty input yields a single empty piece.
func splitParams(params string) []string {
	var (
		pieces  []string
		pending bytes.Buffer
		nesting int
	)

	// The appended comma flushes the final piece through the same path as
	// every other one.
	for _, r := range params + "," {
		if r == ',' && nesting == 0 {
			pieces = append(pieces, strings.TrimSpace(pending.String()))
			pending.Reset()
			continue
		}

		switch r {
		case '(':
			nesting++
		case ')':
			nesting--
		}
		pending.WriteRune(r)
	}

	return pieces
}
|
||||
|
||||
func parseFunctions(content string) []Function {
|
||||
var functions []Function
|
||||
|
||||
// Match function declarations: return_type function_name(params);
|
||||
// Matches both mlx_* and _mlx_* functions
|
||||
pattern := regexp.MustCompile(`\b((?:const\s+)?(?:struct\s+)?[\w\s]+?[\*\s]*)\s+(_?mlx_\w+)\s*\(([^)]*(?:\([^)]*\)[^)]*)*)\)\s*;`)
|
||||
|
||||
matches := pattern.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range matches {
|
||||
returnType := strings.TrimSpace(match[1])
|
||||
funcName := strings.TrimSpace(match[2])
|
||||
params := strings.TrimSpace(match[3])
|
||||
|
||||
// Skip if this looks like a variable declaration
|
||||
if params == "" || strings.Contains(params, "{") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up return type
|
||||
returnType = strings.Join(strings.Fields(returnType), " ")
|
||||
|
||||
// Extract parameter names
|
||||
paramNames := extractParamNames(params)
|
||||
|
||||
// Check if ARM64 guard is needed
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
return functions
|
||||
}
|
||||
|
||||
// needsARM64Guard reports whether a function involves half-precision types:
// its name mentions float16/bfloat16, or its return type or parameter list
// mentions float16_t/bfloat16_t.
func needsARM64Guard(name, retType, params string) bool {
	// "bfloat16" contains "float16" and "bfloat16_t" contains "float16_t",
	// so one substring test per field covers both spellings.
	switch {
	case strings.Contains(name, "float16"):
		return true
	case strings.Contains(retType, "float16_t"):
		return true
	default:
		return strings.Contains(params, "float16_t")
	}
}
|
||||
|
||||
func generateWrapperFiles(functions []Function, headerPath, implPath string) error {
|
||||
// Generate header file
|
||||
var headerBuf bytes.Buffer
|
||||
|
||||
headerBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
|
||||
headerBuf.WriteString("// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym\n")
|
||||
headerBuf.WriteString("//\n")
|
||||
headerBuf.WriteString("// Strategy: Include MLX-C headers for type definitions, then provide wrapper\n")
|
||||
headerBuf.WriteString("// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add).\n")
|
||||
headerBuf.WriteString("// Function pointers are defined in mlx.c (single compilation unit).\n\n")
|
||||
headerBuf.WriteString("#ifndef MLX_WRAPPERS_H\n")
|
||||
headerBuf.WriteString("#define MLX_WRAPPERS_H\n\n")
|
||||
|
||||
headerBuf.WriteString("// Include MLX headers for type definitions and original declarations\n")
|
||||
headerBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
|
||||
headerBuf.WriteString("#include \"mlx_dynamic.h\"\n")
|
||||
headerBuf.WriteString("#include <stdio.h>\n\n")
|
||||
|
||||
// Undef all MLX functions to avoid conflicts
|
||||
headerBuf.WriteString("// Undefine any existing MLX function macros\n")
|
||||
for _, fn := range functions {
|
||||
headerBuf.WriteString(fmt.Sprintf("#undef %s\n", fn.Name))
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
|
||||
// Function pointer extern declarations
|
||||
headerBuf.WriteString("// Function pointer declarations (defined in mlx.c, loaded via dlsym)\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
headerBuf.WriteString(fmt.Sprintf("extern %s (*%s_ptr)(%s);\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
|
||||
// Initialization function declaration
|
||||
headerBuf.WriteString("// Initialize all function pointers via dlsym (defined in mlx.c)\n")
|
||||
headerBuf.WriteString("int mlx_load_functions(void* handle);\n\n")
|
||||
|
||||
// Wrapper function declarations
|
||||
headerBuf.WriteString("// Wrapper function declarations that call through function pointers\n")
|
||||
headerBuf.WriteString("// Go code calls these directly as C.mlx_* (no #define redirection needed)\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
headerBuf.WriteString(fmt.Sprintf("%s %s(%s);\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#endif\n")
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
headerBuf.WriteString("#endif // MLX_WRAPPERS_H\n")
|
||||
|
||||
// Write header file
|
||||
if err := os.WriteFile(headerPath, headerBuf.Bytes(), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write header file: %w", err)
|
||||
}
|
||||
|
||||
// Generate implementation file
|
||||
var implBuf bytes.Buffer
|
||||
|
||||
implBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
|
||||
implBuf.WriteString("// This file contains the function pointer definitions and initialization\n")
|
||||
implBuf.WriteString("// All function pointers are in a single compilation unit to avoid duplication\n\n")
|
||||
|
||||
implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
|
||||
implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
|
||||
implBuf.WriteString("#include <stdio.h>\n")
|
||||
implBuf.WriteString("#include <dlfcn.h>\n\n")
|
||||
|
||||
// Function pointer definitions
|
||||
implBuf.WriteString("// Function pointer definitions\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf("%s (*%s_ptr)(%s) = NULL;\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
implBuf.WriteString("\n")
|
||||
|
||||
// Initialization function
|
||||
implBuf.WriteString("// Initialize all function pointers via dlsym\n")
|
||||
implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
|
||||
implBuf.WriteString(" if (handle == NULL) {\n")
|
||||
implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
|
||||
implBuf.WriteString(" return -1;\n")
|
||||
implBuf.WriteString(" }\n\n")
|
||||
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
|
||||
implBuf.WriteString(" return -1;\n")
|
||||
implBuf.WriteString(" }\n")
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
|
||||
implBuf.WriteString(" return 0;\n")
|
||||
implBuf.WriteString("}\n\n")
|
||||
|
||||
// Wrapper function implementations
|
||||
implBuf.WriteString("// Wrapper function implementations that call through function pointers\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf("%s %s(%s) {\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
|
||||
// Call through function pointer
|
||||
if fn.ReturnType != "void" {
|
||||
implBuf.WriteString(fmt.Sprintf(" return %s_ptr(", fn.Name))
|
||||
} else {
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr(", fn.Name))
|
||||
}
|
||||
|
||||
// Pass parameters
|
||||
implBuf.WriteString(strings.Join(fn.ParamNames, ", "))
|
||||
implBuf.WriteString(");\n")
|
||||
implBuf.WriteString("}\n")
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
implBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write implementation file
|
||||
if err := os.WriteFile(implPath, implBuf.Bytes(), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write implementation file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]\n")
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Generate MLX-C dynamic loading wrappers.\n\n")
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) < 2 {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "ERROR: Missing required arguments\n\n")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
headerDir := args[0]
|
||||
outputHeader := args[1]
|
||||
// Default implementation file is same name with .c extension
|
||||
outputImpl := outputHeader
|
||||
if len(args) > 2 {
|
||||
outputImpl = args[2]
|
||||
} else if strings.HasSuffix(outputHeader, ".h") {
|
||||
outputImpl = outputHeader[:len(outputHeader)-2] + ".c"
|
||||
}
|
||||
|
||||
// Check if header directory exists
|
||||
if _, err := os.Stat(headerDir); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: MLX-C headers directory not found at: %s\n\n", headerDir)
|
||||
fmt.Fprintf(os.Stderr, "Please run CMake first to download MLX-C dependencies:\n")
|
||||
fmt.Fprintf(os.Stderr, " cmake -B build\n\n")
|
||||
fmt.Fprintf(os.Stderr, "The CMake build will download and extract MLX-C headers needed for wrapper generation.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Parsing MLX-C headers from: %s\n", headerDir)
|
||||
|
||||
// Find all headers
|
||||
headers, err := findHeaders(headerDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Failed to find header files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Found %d header files\n", len(headers))
|
||||
|
||||
// Parse all headers
|
||||
var allFunctions []Function
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, header := range headers {
|
||||
content, err := os.ReadFile(header)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading %s: %v\n", header, err)
|
||||
continue
|
||||
}
|
||||
|
||||
cleaned := cleanContent(string(content))
|
||||
functions := parseFunctions(cleaned)
|
||||
|
||||
// Deduplicate
|
||||
for _, fn := range functions {
|
||||
if !seen[fn.Name] {
|
||||
seen[fn.Name] = true
|
||||
allFunctions = append(allFunctions, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Found %d unique function declarations\n", len(allFunctions))
|
||||
|
||||
// Generate wrapper files
|
||||
if err := generateWrapperFiles(allFunctions, outputHeader, outputImpl); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Failed to generate wrapper files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Generated %s and %s successfully\n", outputHeader, outputImpl)
|
||||
}
|
||||
5786
x/imagegen/mlx/mlx.c
5786
x/imagegen/mlx/mlx.c
File diff suppressed because it is too large
Load Diff
@@ -3,13 +3,12 @@
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR}
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src
|
||||
#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/
|
||||
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
|
||||
#cgo linux LDFLAGS: -lstdc++ -ldl
|
||||
#cgo windows LDFLAGS: -lstdc++
|
||||
#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc
|
||||
|
||||
// Use generated wrappers instead of direct MLX headers
|
||||
#include "mlx.h"
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
@@ -43,6 +42,192 @@ static inline mlx_stream cpu_stream() {
|
||||
// CGO noescape/nocallback hints to reduce CGO overhead
|
||||
// noescape: pointers won't escape, no heap allocation needed
|
||||
// nocallback: function won't call back into Go
|
||||
#cgo noescape mlx_add
|
||||
#cgo nocallback mlx_add
|
||||
#cgo noescape mlx_subtract
|
||||
#cgo nocallback mlx_subtract
|
||||
#cgo noescape mlx_multiply
|
||||
#cgo nocallback mlx_multiply
|
||||
#cgo noescape mlx_divide
|
||||
#cgo nocallback mlx_divide
|
||||
#cgo noescape mlx_negative
|
||||
#cgo nocallback mlx_negative
|
||||
#cgo noescape mlx_abs
|
||||
#cgo nocallback mlx_abs
|
||||
#cgo noescape mlx_exp
|
||||
#cgo nocallback mlx_exp
|
||||
#cgo noescape mlx_log
|
||||
#cgo nocallback mlx_log
|
||||
#cgo noescape mlx_sqrt
|
||||
#cgo nocallback mlx_sqrt
|
||||
#cgo noescape mlx_rsqrt
|
||||
#cgo nocallback mlx_rsqrt
|
||||
#cgo noescape mlx_square
|
||||
#cgo nocallback mlx_square
|
||||
#cgo noescape mlx_power
|
||||
#cgo nocallback mlx_power
|
||||
#cgo noescape mlx_erf
|
||||
#cgo nocallback mlx_erf
|
||||
#cgo noescape mlx_sigmoid
|
||||
#cgo nocallback mlx_sigmoid
|
||||
#cgo noescape mlx_tanh
|
||||
#cgo nocallback mlx_tanh
|
||||
#cgo noescape mlx_sin
|
||||
#cgo nocallback mlx_sin
|
||||
#cgo noescape mlx_cos
|
||||
#cgo nocallback mlx_cos
|
||||
#cgo noescape mlx_maximum
|
||||
#cgo nocallback mlx_maximum
|
||||
#cgo noescape mlx_minimum
|
||||
#cgo nocallback mlx_minimum
|
||||
#cgo noescape mlx_clip
|
||||
#cgo nocallback mlx_clip
|
||||
#cgo noescape mlx_sum
|
||||
#cgo nocallback mlx_sum
|
||||
#cgo noescape mlx_sum_axis
|
||||
#cgo nocallback mlx_sum_axis
|
||||
#cgo noescape mlx_mean
|
||||
#cgo nocallback mlx_mean
|
||||
#cgo noescape mlx_mean_axis
|
||||
#cgo nocallback mlx_mean_axis
|
||||
#cgo noescape mlx_var_axis
|
||||
#cgo nocallback mlx_var_axis
|
||||
#cgo noescape mlx_argmax
|
||||
#cgo nocallback mlx_argmax
|
||||
#cgo noescape mlx_argmax_axis
|
||||
#cgo nocallback mlx_argmax_axis
|
||||
#cgo noescape mlx_softmax_axis
|
||||
#cgo nocallback mlx_softmax_axis
|
||||
#cgo noescape mlx_cumsum
|
||||
#cgo nocallback mlx_cumsum
|
||||
#cgo noescape mlx_matmul
|
||||
#cgo nocallback mlx_matmul
|
||||
#cgo noescape mlx_addmm
|
||||
#cgo nocallback mlx_addmm
|
||||
#cgo noescape mlx_gather_mm
|
||||
#cgo nocallback mlx_gather_mm
|
||||
#cgo noescape mlx_gather_qmm
|
||||
#cgo nocallback mlx_gather_qmm
|
||||
#cgo noescape mlx_reshape
|
||||
#cgo nocallback mlx_reshape
|
||||
#cgo noescape mlx_transpose_axes
|
||||
#cgo nocallback mlx_transpose_axes
|
||||
#cgo noescape mlx_expand_dims
|
||||
#cgo nocallback mlx_expand_dims
|
||||
#cgo noescape mlx_squeeze_axis
|
||||
#cgo nocallback mlx_squeeze_axis
|
||||
#cgo noescape mlx_flatten
|
||||
#cgo nocallback mlx_flatten
|
||||
#cgo noescape mlx_concatenate_axis
|
||||
#cgo nocallback mlx_concatenate_axis
|
||||
#cgo noescape mlx_slice
|
||||
#cgo nocallback mlx_slice
|
||||
#cgo noescape mlx_slice_update
|
||||
#cgo nocallback mlx_slice_update
|
||||
#cgo noescape mlx_as_strided
|
||||
#cgo nocallback mlx_as_strided
|
||||
#cgo noescape mlx_view
|
||||
#cgo nocallback mlx_view
|
||||
#cgo noescape mlx_contiguous
|
||||
#cgo nocallback mlx_contiguous
|
||||
#cgo noescape mlx_pad
|
||||
#cgo nocallback mlx_pad
|
||||
#cgo noescape mlx_tile
|
||||
#cgo nocallback mlx_tile
|
||||
#cgo noescape mlx_take_axis
|
||||
#cgo nocallback mlx_take_axis
|
||||
#cgo noescape mlx_take_along_axis
|
||||
#cgo nocallback mlx_take_along_axis
|
||||
#cgo noescape mlx_put_along_axis
|
||||
#cgo nocallback mlx_put_along_axis
|
||||
#cgo noescape mlx_where
|
||||
#cgo nocallback mlx_where
|
||||
#cgo noescape mlx_argsort_axis
|
||||
#cgo nocallback mlx_argsort_axis
|
||||
#cgo noescape mlx_argpartition_axis
|
||||
#cgo nocallback mlx_argpartition_axis
|
||||
#cgo noescape mlx_topk_axis
|
||||
#cgo nocallback mlx_topk_axis
|
||||
#cgo noescape mlx_less
|
||||
#cgo nocallback mlx_less
|
||||
#cgo noescape mlx_greater_equal
|
||||
#cgo nocallback mlx_greater_equal
|
||||
#cgo noescape mlx_logical_and
|
||||
#cgo nocallback mlx_logical_and
|
||||
#cgo noescape mlx_zeros
|
||||
#cgo nocallback mlx_zeros
|
||||
#cgo noescape mlx_zeros_like
|
||||
#cgo nocallback mlx_zeros_like
|
||||
#cgo noescape mlx_ones
|
||||
#cgo nocallback mlx_ones
|
||||
#cgo noescape mlx_full
|
||||
#cgo nocallback mlx_full
|
||||
#cgo noescape mlx_arange
|
||||
#cgo nocallback mlx_arange
|
||||
#cgo noescape mlx_linspace
|
||||
#cgo nocallback mlx_linspace
|
||||
#cgo noescape mlx_tri
|
||||
#cgo nocallback mlx_tri
|
||||
#cgo noescape mlx_astype
|
||||
#cgo nocallback mlx_astype
|
||||
#cgo noescape mlx_fast_rms_norm
|
||||
#cgo nocallback mlx_fast_rms_norm
|
||||
#cgo noescape mlx_fast_rope
|
||||
#cgo nocallback mlx_fast_rope
|
||||
#cgo noescape mlx_fast_scaled_dot_product_attention
|
||||
#cgo nocallback mlx_fast_scaled_dot_product_attention
|
||||
#cgo noescape mlx_conv2d
|
||||
#cgo nocallback mlx_conv2d
|
||||
#cgo noescape mlx_conv3d
|
||||
#cgo nocallback mlx_conv3d
|
||||
#cgo noescape mlx_random_key
|
||||
#cgo nocallback mlx_random_key
|
||||
#cgo noescape mlx_random_split
|
||||
#cgo nocallback mlx_random_split
|
||||
#cgo noescape mlx_random_categorical_num_samples
|
||||
#cgo nocallback mlx_random_categorical_num_samples
|
||||
#cgo noescape mlx_random_normal
|
||||
#cgo nocallback mlx_random_normal
|
||||
#cgo noescape mlx_random_uniform
|
||||
#cgo nocallback mlx_random_uniform
|
||||
#cgo noescape mlx_array_eval
|
||||
#cgo nocallback mlx_array_eval
|
||||
#cgo noescape mlx_eval
|
||||
#cgo nocallback mlx_eval
|
||||
#cgo noescape mlx_async_eval
|
||||
#cgo nocallback mlx_async_eval
|
||||
#cgo noescape mlx_synchronize
|
||||
#cgo nocallback mlx_synchronize
|
||||
#cgo noescape mlx_array_new
|
||||
#cgo nocallback mlx_array_new
|
||||
#cgo noescape mlx_array_new_data
|
||||
#cgo nocallback mlx_array_new_data
|
||||
#cgo noescape mlx_array_new_float
|
||||
#cgo nocallback mlx_array_new_float
|
||||
#cgo noescape mlx_array_free
|
||||
#cgo nocallback mlx_array_free
|
||||
#cgo noescape mlx_array_size
|
||||
#cgo nocallback mlx_array_size
|
||||
#cgo noescape mlx_array_ndim
|
||||
#cgo nocallback mlx_array_ndim
|
||||
#cgo noescape mlx_array_dim
|
||||
#cgo nocallback mlx_array_dim
|
||||
#cgo noescape mlx_array_dtype
|
||||
#cgo nocallback mlx_array_dtype
|
||||
#cgo noescape mlx_array_item_int32
|
||||
#cgo nocallback mlx_array_item_int32
|
||||
#cgo noescape mlx_vector_array_new_data
|
||||
#cgo nocallback mlx_vector_array_new_data
|
||||
#cgo noescape mlx_vector_array_free
|
||||
#cgo nocallback mlx_vector_array_free
|
||||
#cgo noescape mlx_array_new_int
|
||||
#cgo nocallback mlx_array_new_int
|
||||
#cgo noescape mlx_stream_new_device
|
||||
#cgo nocallback mlx_stream_new_device
|
||||
#cgo noescape mlx_get_default_stream
|
||||
#cgo nocallback mlx_get_default_stream
|
||||
#cgo noescape mlx_set_default_stream
|
||||
#cgo nocallback mlx_set_default_stream
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
@@ -1611,57 +1796,7 @@ func ArgmaxKeepArray(logits *Array) *Array {
|
||||
var RandomState = []*Array{nil}
|
||||
var randomStateMu sync.Mutex
|
||||
|
||||
var mlxInitialized bool
|
||||
var mlxInitError error
|
||||
|
||||
// InitMLX initializes the MLX library by dynamically loading libmlxc.
|
||||
// This must be called before using any MLX functions.
|
||||
// Returns an error if the library cannot be loaded.
|
||||
func InitMLX() error {
|
||||
if mlxInitialized {
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
// Try to load the MLX dynamic library
|
||||
ret := C.mlx_dynamic_init()
|
||||
if ret != 0 {
|
||||
errMsg := C.GoString(C.mlx_dynamic_error())
|
||||
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
// Initialize all function pointers via dlsym
|
||||
handle := C.mlx_get_handle()
|
||||
ret = C.mlx_load_functions(handle)
|
||||
if ret != 0 {
|
||||
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
mlxInitialized = true
|
||||
mlxInitError = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsMLXAvailable returns whether MLX was successfully initialized
|
||||
func IsMLXAvailable() bool {
|
||||
return mlxInitialized && mlxInitError == nil
|
||||
}
|
||||
|
||||
// GetMLXInitError returns any error that occurred during MLX initialization
|
||||
func GetMLXInitError() error {
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Initialize MLX dynamic library first
|
||||
if err := InitMLX(); err != nil {
|
||||
// Don't panic in init - let the caller handle the error
|
||||
// Store the error for later retrieval
|
||||
mlxInitError = err
|
||||
return
|
||||
}
|
||||
|
||||
// Lock main goroutine to OS thread for CUDA context stability.
|
||||
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
|
||||
runtime.LockOSThread()
|
||||
|
||||
2337
x/imagegen/mlx/mlx.h
2337
x/imagegen/mlx/mlx.h
File diff suppressed because it is too large
Load Diff
@@ -1,144 +0,0 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
#include <mach-o/dyld.h>
|
||||
#include <libgen.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
#ifdef __APPLE__
|
||||
// Build the path of `libname` inside the directory that holds the running
// executable.
//
// Returns a pointer to a static buffer that is overwritten by the next call,
// or NULL when the executable path cannot be obtained, its directory cannot
// be determined, or the combined path does not fit in the buffer.
static char* get_exe_relative_path(const char* libname) {
    static char path[1024];
    static char fullpath[1024];

    uint32_t size = sizeof(path);
    if (_NSGetExecutablePath(path, &size) != 0) {
        return NULL;
    }

    // Get directory of executable
    char* dir = dirname(path);
    if (dir == NULL) {
        return NULL;
    }

    // snprintf reports the length it wanted to write; anything at or beyond
    // the buffer size means the path was truncated and must not be used.
    int written = snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname);
    if (written < 0 || (size_t)written >= sizeof(fullpath)) {
        return NULL;
    }
    return fullpath;
}
|
||||
#endif
|
||||
|
||||
// Try to load library from a specific path
|
||||
static int try_load_lib(const char* path) {
|
||||
if (!path) return 0;
|
||||
mlx_handle = LOAD_LIB(path);
|
||||
return mlx_handle != NULL;
|
||||
}
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
const char* lib_path = NULL;
|
||||
const char* tried_paths[8] = {0};
|
||||
int num_tried = 0;
|
||||
|
||||
#ifdef _WIN32
|
||||
// Windows: try same directory as executable
|
||||
lib_path = "libmlxc.dll";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#elif defined(__APPLE__)
|
||||
// macOS: try executable directory first
|
||||
lib_path = get_exe_relative_path("libmlxc.dylib");
|
||||
if (lib_path) {
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
}
|
||||
// Try build directory (for tests run from repo root)
|
||||
lib_path = "./build/lib/ollama/libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#else
|
||||
// Linux: try build directory first (for tests)
|
||||
lib_path = "./build/lib/ollama/libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#endif
|
||||
|
||||
// Failed to load library - build error message with all tried paths
|
||||
{
|
||||
const char* err = LIB_ERROR();
|
||||
int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. Tried: ");
|
||||
for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) {
|
||||
offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
"%s%s", i > 0 ? ", " : "", tried_paths[i]);
|
||||
}
|
||||
if (err) {
|
||||
snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
". Last error: %s", err);
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
|
||||
success:
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", lib_path ? lib_path : "library");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Get the last error message
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Get the library handle (for use by generated wrappers)
|
||||
void* mlx_get_handle(void) {
|
||||
return mlx_handle;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Get the library handle (for use by generated wrappers)
|
||||
void* mlx_get_handle(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
@@ -4,30 +4,9 @@ package mlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping MLX tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive.
|
||||
func TestBasicCleanup(t *testing.T) {
|
||||
weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
|
||||
|
||||
@@ -3,33 +3,12 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
|
||||
@@ -3,35 +3,13 @@
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
@@ -3,34 +3,12 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping nn tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
|
||||
func TestLinearNoBias(t *testing.T) {
|
||||
// Weight: [out=2, in=3] -> transposed at forward time
|
||||
|
||||
@@ -62,12 +62,6 @@ func Execute(args []string) error {
|
||||
return fmt.Errorf("--port is required")
|
||||
}
|
||||
|
||||
err := mlx.InitMLX()
|
||||
if err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
||||
|
||||
// Check memory requirements before loading
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -62,7 +61,7 @@ func NewServer(modelName string) (*Server, error) {
|
||||
port = rand.Intn(65535-49152) + 49152
|
||||
}
|
||||
|
||||
// Get the current executable path (we use the same binary with runner subcommand)
|
||||
// Get the ollama-mlx executable path (in same directory as current executable)
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
@@ -70,9 +69,10 @@ func NewServer(modelName string) (*Server, error) {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
|
||||
|
||||
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
@@ -134,7 +134,7 @@ func NewServer(modelName string) (*Server, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
|
||||
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start image runner: %w", err)
|
||||
}
|
||||
@@ -232,11 +232,13 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
}
|
||||
|
||||
@@ -279,15 +281,11 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
// Convert to llm.CompletionResponse
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
Total: raw.Total,
|
||||
}
|
||||
if raw.Image != "" {
|
||||
if data, err := base64.StdEncoding.DecodeString(raw.Image); err == nil {
|
||||
cresp.Image = data
|
||||
}
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
Step: raw.Step,
|
||||
TotalSteps: raw.Total,
|
||||
Image: raw.Image,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
include(FetchContent)
|
||||
|
||||
# Read MLX version from top-level file (shared with Dockerfile)
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
@@ -54,7 +50,7 @@ endif()
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG})
|
||||
GIT_TAG v0.4.1)
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
set_target_output_directory(mlx)
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.dylib",
|
||||
"@loader_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"@executable_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"build/lib/ollama/libmlxc.dylib",
|
||||
"../build/lib/ollama/libmlxc.dylib",
|
||||
NULL
|
||||
};
|
||||
#else
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.so",
|
||||
"$ORIGIN/../build/lib/ollama/libmlxc.so",
|
||||
"build/lib/ollama/libmlxc.so",
|
||||
"../build/lib/ollama/libmlxc.so",
|
||||
NULL
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
// Try each possible library path
|
||||
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
|
||||
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
|
||||
if (mlx_handle != NULL) {
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", LIB_NAMES[i]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to load library
|
||||
const char* err = LIB_ERROR();
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. %s",
|
||||
err ? err : "Unknown error");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get the last error message
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
Reference in New Issue
Block a user