mirror of
https://github.com/ollama/ollama.git
synced 2026-01-16 19:41:24 -05:00
Compare commits
1 Commits
imagegen-g
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12719b6e87 |
18
Dockerfile
18
Dockerfile
@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||
# install epel-release for ccache
|
||||
RUN yum install -y yum-utils epel-release \
|
||||
&& dnf install -y clang ccache \
|
||||
&& dnf install -y clang ccache git \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
|
||||
ENV CC=clang CXX=clang++
|
||||
|
||||
@@ -149,6 +149,7 @@ COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
@@ -156,14 +157,6 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN mkdir -p dist/bin
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -172,12 +165,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -185,7 +180,6 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
1
MLX_VERSION
Normal file
1
MLX_VERSION
Normal file
@@ -0,0 +1 @@
|
||||
v0.4.1
|
||||
@@ -270,10 +270,10 @@ cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
|
||||
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
|
||||
|
||||
```shell
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
go build -tags mlx .
|
||||
```
|
||||
|
||||
Finally, start the server:
|
||||
|
||||
@@ -60,7 +60,7 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
cmake --preset MLX \
|
||||
@@ -71,10 +71,12 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
|
||||
# Copy MLX libraries to same directory as executable for dlopen
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
|
||||
done
|
||||
}
|
||||
|
||||
@@ -82,12 +84,10 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/ollama-mlx
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
@@ -154,7 +154,6 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
@@ -166,13 +165,12 @@ _build_macapp() {
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
@@ -180,7 +178,7 @@ _build_macapp() {
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
|
||||
@@ -65,12 +65,12 @@ func (s *utf8Streamer) Flush() string {
|
||||
return result
|
||||
}
|
||||
|
||||
func init() {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
// Lazy initialization of generationStream
|
||||
if generationStream == nil {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
@@ -79,6 +78,11 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MLX initialized successfully
|
||||
if !mlx.IsMLXAvailable() {
|
||||
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
|
||||
}
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
#include "mlx.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// Forward declaration for Go callback
|
||||
|
||||
6
x/imagegen/mlx/doc.go
Normal file
6
x/imagegen/mlx/doc.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
|
||||
//
|
||||
//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
|
||||
package mlx
|
||||
439
x/imagegen/mlx/generate_wrappers.go
Normal file
439
x/imagegen/mlx/generate_wrappers.go
Normal file
@@ -0,0 +1,439 @@
|
||||
//go:build ignore
|
||||
|
||||
// This tool generates MLX-C dynamic loading wrappers.
|
||||
// Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Function describes one C function declaration scraped from an MLX-C header.
// It carries everything the generator needs to emit the function pointer, the
// dlsym loader entry and the forwarding wrapper for that declaration.
type Function struct {
	// Name is the C identifier of the function (an mlx_* or _mlx_* symbol).
	Name string

	// ReturnType is the C return type with runs of whitespace collapsed to
	// single spaces.
	ReturnType string

	// Params is the trimmed text found between the declaration's parentheses.
	Params string

	// ParamNames lists the parameter identifiers extracted from Params, in
	// declaration order; they are forwarded verbatim by the generated wrapper.
	ParamNames []string

	// NeedsARM64Guard reports whether every piece of generated code for this
	// function is wrapped in an aarch64-only preprocessor guard.
	NeedsARM64Guard bool
}
|
||||
|
||||
// findHeaders walks directory recursively and collects the path of every
// non-directory entry whose path ends in ".h".
//
// The first error reported by the walk stops the traversal; it is returned
// together with whatever headers were gathered up to that point.
func findHeaders(directory string) ([]string, error) {
	var headers []string

	visit := func(path string, d fs.DirEntry, err error) error {
		switch {
		case err != nil:
			// Propagate walk errors unchanged so WalkDir aborts.
			return err
		case d.IsDir(), !strings.HasSuffix(path, ".h"):
			return nil
		}
		headers = append(headers, path)
		return nil
	}

	walkErr := filepath.WalkDir(directory, visit)
	return headers, walkErr
}
|
||||
|
||||
// cleanContent reduces the text of a C header to a single whitespace-normalised
// line that is easier to scan for function declarations.
//
// It removes, in order: comments, preprocessor directive lines, the opening
// `extern "C" {` line, and closing braces that stand alone on a line. Finally
// every run of whitespace (including newlines) is collapsed to one space.
func cleanContent(content string) string {
	// Strip block and line comments in a single left-to-right pass so that a
	// "//" inside a block comment, or a "/*" inside a line comment, is not
	// mistaken for the start of another comment. (?s) lets "." cross newlines;
	// without it block comments spanning several lines are never removed.
	// Each comment becomes one space so adjacent tokens are not fused.
	comments := regexp.MustCompile(`(?s)/\*.*?\*/|//[^\n]*`)
	content = comments.ReplaceAllString(content, " ")

	// Remove preprocessor directives (lines starting with #) - multiline mode
	// makes ^ and $ match at line boundaries.
	directives := regexp.MustCompile(`(?m)^\s*#.*?$`)
	content = directives.ReplaceAllString(content, "")

	// Remove only the `extern "C" {` line itself, keeping the declarations
	// that follow it.
	externOpen := regexp.MustCompile(`extern\s+"C"\s*\{\s*?\n`)
	content = externOpen.ReplaceAllString(content, "\n")

	// Remove closing braces that sit alone on a line (the end of the
	// extern "C" block).
	loneBrace := regexp.MustCompile(`\n\s*\}\s*\n`)
	content = loneBrace.ReplaceAllString(content, "\n")

	// Collapse whitespace and newlines.
	spaces := regexp.MustCompile(`\s+`)
	return spaces.ReplaceAllString(content, " ")
}
|
||||
|
||||
func extractParamNames(params string) []string {
|
||||
if params == "" || strings.TrimSpace(params) == "void" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var names []string
|
||||
|
||||
// Split by comma, but respect parentheses (for function pointers)
|
||||
parts := splitParams(params)
|
||||
|
||||
// Remove array brackets
|
||||
arrayBrackets := regexp.MustCompile(`\[.*?\]`)
|
||||
|
||||
// Function pointer pattern
|
||||
funcPtrPattern := regexp.MustCompile(`\(\s*\*\s*(\w+)\s*\)`)
|
||||
|
||||
// Type keywords to skip
|
||||
typeKeywords := map[string]bool{
|
||||
"const": true,
|
||||
"struct": true,
|
||||
"unsigned": true,
|
||||
"signed": true,
|
||||
"long": true,
|
||||
"short": true,
|
||||
"int": true,
|
||||
"char": true,
|
||||
"float": true,
|
||||
"double": true,
|
||||
"void": true,
|
||||
"size_t": true,
|
||||
"uint8_t": true,
|
||||
"uint16_t": true,
|
||||
"uint32_t": true,
|
||||
"uint64_t": true,
|
||||
"int8_t": true,
|
||||
"int16_t": true,
|
||||
"int32_t": true,
|
||||
"int64_t": true,
|
||||
"intptr_t": true,
|
||||
"uintptr_t": true,
|
||||
}
|
||||
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove array brackets
|
||||
part = arrayBrackets.ReplaceAllString(part, "")
|
||||
|
||||
// For function pointers like "void (*callback)(int)"
|
||||
if matches := funcPtrPattern.FindStringSubmatch(part); len(matches) > 1 {
|
||||
names = append(names, matches[1])
|
||||
continue
|
||||
}
|
||||
|
||||
// Regular parameter: last identifier
|
||||
tokens := regexp.MustCompile(`\w+`).FindAllString(part, -1)
|
||||
if len(tokens) > 0 {
|
||||
// The last token is usually the parameter name
|
||||
// Skip type keywords
|
||||
for i := len(tokens) - 1; i >= 0; i-- {
|
||||
if !typeKeywords[tokens[i]] {
|
||||
names = append(names, tokens[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
// splitParams breaks a C parameter list into its comma-separated pieces,
// ignoring commas nested inside parentheses so that function-pointer
// parameters such as "void (*cb)(int, int)" stay intact. Each piece is
// returned with surrounding whitespace trimmed; an empty input yields a
// single empty piece.
func splitParams(params string) []string {
	var (
		pieces  []string
		pending bytes.Buffer
		nesting int
	)

	// A trailing comma is appended so the final parameter is flushed by the
	// same code path as all the others.
	for _, r := range params + "," {
		if r == ',' && nesting == 0 {
			pieces = append(pieces, strings.TrimSpace(pending.String()))
			pending.Reset()
			continue
		}

		if r == '(' {
			nesting++
		} else if r == ')' {
			nesting--
		}
		pending.WriteRune(r)
	}

	return pieces
}
|
||||
|
||||
func parseFunctions(content string) []Function {
|
||||
var functions []Function
|
||||
|
||||
// Match function declarations: return_type function_name(params);
|
||||
// Matches both mlx_* and _mlx_* functions
|
||||
pattern := regexp.MustCompile(`\b((?:const\s+)?(?:struct\s+)?[\w\s]+?[\*\s]*)\s+(_?mlx_\w+)\s*\(([^)]*(?:\([^)]*\)[^)]*)*)\)\s*;`)
|
||||
|
||||
matches := pattern.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range matches {
|
||||
returnType := strings.TrimSpace(match[1])
|
||||
funcName := strings.TrimSpace(match[2])
|
||||
params := strings.TrimSpace(match[3])
|
||||
|
||||
// Skip if this looks like a variable declaration
|
||||
if params == "" || strings.Contains(params, "{") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up return type
|
||||
returnType = strings.Join(strings.Fields(returnType), " ")
|
||||
|
||||
// Extract parameter names
|
||||
paramNames := extractParamNames(params)
|
||||
|
||||
// Check if ARM64 guard is needed
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
return functions
|
||||
}
|
||||
|
||||
// needsARM64Guard reports whether a declaration involves half-precision
// types: either its name mentions float16/bfloat16, or its return type or
// parameter list mentions float16_t/bfloat16_t.
func needsARM64Guard(name, retType, params string) bool {
	// "bfloat16" contains "float16", so one substring test covers both names.
	if strings.Contains(name, "float16") {
		return true
	}

	for _, signature := range []string{retType, params} {
		if strings.Contains(signature, "float16_t") || strings.Contains(signature, "bfloat16_t") {
			return true
		}
	}
	return false
}
|
||||
|
||||
func generateWrapperFiles(functions []Function, headerPath, implPath string) error {
|
||||
// Generate header file
|
||||
var headerBuf bytes.Buffer
|
||||
|
||||
headerBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
|
||||
headerBuf.WriteString("// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym\n")
|
||||
headerBuf.WriteString("//\n")
|
||||
headerBuf.WriteString("// Strategy: Include MLX-C headers for type definitions, then provide wrapper\n")
|
||||
headerBuf.WriteString("// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add).\n")
|
||||
headerBuf.WriteString("// Function pointers are defined in mlx.c (single compilation unit).\n\n")
|
||||
headerBuf.WriteString("#ifndef MLX_WRAPPERS_H\n")
|
||||
headerBuf.WriteString("#define MLX_WRAPPERS_H\n\n")
|
||||
|
||||
headerBuf.WriteString("// Include MLX headers for type definitions and original declarations\n")
|
||||
headerBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
|
||||
headerBuf.WriteString("#include \"mlx_dynamic.h\"\n")
|
||||
headerBuf.WriteString("#include <stdio.h>\n\n")
|
||||
|
||||
// Undef all MLX functions to avoid conflicts
|
||||
headerBuf.WriteString("// Undefine any existing MLX function macros\n")
|
||||
for _, fn := range functions {
|
||||
headerBuf.WriteString(fmt.Sprintf("#undef %s\n", fn.Name))
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
|
||||
// Function pointer extern declarations
|
||||
headerBuf.WriteString("// Function pointer declarations (defined in mlx.c, loaded via dlsym)\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
headerBuf.WriteString(fmt.Sprintf("extern %s (*%s_ptr)(%s);\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
|
||||
// Initialization function declaration
|
||||
headerBuf.WriteString("// Initialize all function pointers via dlsym (defined in mlx.c)\n")
|
||||
headerBuf.WriteString("int mlx_load_functions(void* handle);\n\n")
|
||||
|
||||
// Wrapper function declarations
|
||||
headerBuf.WriteString("// Wrapper function declarations that call through function pointers\n")
|
||||
headerBuf.WriteString("// Go code calls these directly as C.mlx_* (no #define redirection needed)\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
headerBuf.WriteString(fmt.Sprintf("%s %s(%s);\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#endif\n")
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
headerBuf.WriteString("#endif // MLX_WRAPPERS_H\n")
|
||||
|
||||
// Write header file
|
||||
if err := os.WriteFile(headerPath, headerBuf.Bytes(), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write header file: %w", err)
|
||||
}
|
||||
|
||||
// Generate implementation file
|
||||
var implBuf bytes.Buffer
|
||||
|
||||
implBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
|
||||
implBuf.WriteString("// This file contains the function pointer definitions and initialization\n")
|
||||
implBuf.WriteString("// All function pointers are in a single compilation unit to avoid duplication\n\n")
|
||||
|
||||
implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
|
||||
implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
|
||||
implBuf.WriteString("#include <stdio.h>\n")
|
||||
implBuf.WriteString("#include <dlfcn.h>\n\n")
|
||||
|
||||
// Function pointer definitions
|
||||
implBuf.WriteString("// Function pointer definitions\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf("%s (*%s_ptr)(%s) = NULL;\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
implBuf.WriteString("\n")
|
||||
|
||||
// Initialization function
|
||||
implBuf.WriteString("// Initialize all function pointers via dlsym\n")
|
||||
implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
|
||||
implBuf.WriteString(" if (handle == NULL) {\n")
|
||||
implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
|
||||
implBuf.WriteString(" return -1;\n")
|
||||
implBuf.WriteString(" }\n\n")
|
||||
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
|
||||
implBuf.WriteString(" return -1;\n")
|
||||
implBuf.WriteString(" }\n")
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
|
||||
implBuf.WriteString(" return 0;\n")
|
||||
implBuf.WriteString("}\n\n")
|
||||
|
||||
// Wrapper function implementations
|
||||
implBuf.WriteString("// Wrapper function implementations that call through function pointers\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf("%s %s(%s) {\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
|
||||
// Call through function pointer
|
||||
if fn.ReturnType != "void" {
|
||||
implBuf.WriteString(fmt.Sprintf(" return %s_ptr(", fn.Name))
|
||||
} else {
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr(", fn.Name))
|
||||
}
|
||||
|
||||
// Pass parameters
|
||||
implBuf.WriteString(strings.Join(fn.ParamNames, ", "))
|
||||
implBuf.WriteString(");\n")
|
||||
implBuf.WriteString("}\n")
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
implBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write implementation file
|
||||
if err := os.WriteFile(implPath, implBuf.Bytes(), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write implementation file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]\n")
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Generate MLX-C dynamic loading wrappers.\n\n")
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) < 2 {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "ERROR: Missing required arguments\n\n")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
headerDir := args[0]
|
||||
outputHeader := args[1]
|
||||
// Default implementation file is same name with .c extension
|
||||
outputImpl := outputHeader
|
||||
if len(args) > 2 {
|
||||
outputImpl = args[2]
|
||||
} else if strings.HasSuffix(outputHeader, ".h") {
|
||||
outputImpl = outputHeader[:len(outputHeader)-2] + ".c"
|
||||
}
|
||||
|
||||
// Check if header directory exists
|
||||
if _, err := os.Stat(headerDir); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: MLX-C headers directory not found at: %s\n\n", headerDir)
|
||||
fmt.Fprintf(os.Stderr, "Please run CMake first to download MLX-C dependencies:\n")
|
||||
fmt.Fprintf(os.Stderr, " cmake -B build\n\n")
|
||||
fmt.Fprintf(os.Stderr, "The CMake build will download and extract MLX-C headers needed for wrapper generation.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Parsing MLX-C headers from: %s\n", headerDir)
|
||||
|
||||
// Find all headers
|
||||
headers, err := findHeaders(headerDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Failed to find header files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Found %d header files\n", len(headers))
|
||||
|
||||
// Parse all headers
|
||||
var allFunctions []Function
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, header := range headers {
|
||||
content, err := os.ReadFile(header)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading %s: %v\n", header, err)
|
||||
continue
|
||||
}
|
||||
|
||||
cleaned := cleanContent(string(content))
|
||||
functions := parseFunctions(cleaned)
|
||||
|
||||
// Deduplicate
|
||||
for _, fn := range functions {
|
||||
if !seen[fn.Name] {
|
||||
seen[fn.Name] = true
|
||||
allFunctions = append(allFunctions, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Found %d unique function declarations\n", len(allFunctions))
|
||||
|
||||
// Generate wrapper files
|
||||
if err := generateWrapperFiles(allFunctions, outputHeader, outputImpl); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Failed to generate wrapper files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Generated %s and %s successfully\n", outputHeader, outputImpl)
|
||||
}
|
||||
5786
x/imagegen/mlx/mlx.c
Normal file
5786
x/imagegen/mlx/mlx.c
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,12 +3,13 @@
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src
|
||||
#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR}
|
||||
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
|
||||
#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc
|
||||
#cgo linux LDFLAGS: -lstdc++ -ldl
|
||||
#cgo windows LDFLAGS: -lstdc++
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
// Use generated wrappers instead of direct MLX headers
|
||||
#include "mlx.h"
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
@@ -42,192 +43,6 @@ static inline mlx_stream cpu_stream() {
|
||||
// CGO noescape/nocallback hints to reduce CGO overhead
|
||||
// noescape: pointers won't escape, no heap allocation needed
|
||||
// nocallback: function won't call back into Go
|
||||
#cgo noescape mlx_add
|
||||
#cgo nocallback mlx_add
|
||||
#cgo noescape mlx_subtract
|
||||
#cgo nocallback mlx_subtract
|
||||
#cgo noescape mlx_multiply
|
||||
#cgo nocallback mlx_multiply
|
||||
#cgo noescape mlx_divide
|
||||
#cgo nocallback mlx_divide
|
||||
#cgo noescape mlx_negative
|
||||
#cgo nocallback mlx_negative
|
||||
#cgo noescape mlx_abs
|
||||
#cgo nocallback mlx_abs
|
||||
#cgo noescape mlx_exp
|
||||
#cgo nocallback mlx_exp
|
||||
#cgo noescape mlx_log
|
||||
#cgo nocallback mlx_log
|
||||
#cgo noescape mlx_sqrt
|
||||
#cgo nocallback mlx_sqrt
|
||||
#cgo noescape mlx_rsqrt
|
||||
#cgo nocallback mlx_rsqrt
|
||||
#cgo noescape mlx_square
|
||||
#cgo nocallback mlx_square
|
||||
#cgo noescape mlx_power
|
||||
#cgo nocallback mlx_power
|
||||
#cgo noescape mlx_erf
|
||||
#cgo nocallback mlx_erf
|
||||
#cgo noescape mlx_sigmoid
|
||||
#cgo nocallback mlx_sigmoid
|
||||
#cgo noescape mlx_tanh
|
||||
#cgo nocallback mlx_tanh
|
||||
#cgo noescape mlx_sin
|
||||
#cgo nocallback mlx_sin
|
||||
#cgo noescape mlx_cos
|
||||
#cgo nocallback mlx_cos
|
||||
#cgo noescape mlx_maximum
|
||||
#cgo nocallback mlx_maximum
|
||||
#cgo noescape mlx_minimum
|
||||
#cgo nocallback mlx_minimum
|
||||
#cgo noescape mlx_clip
|
||||
#cgo nocallback mlx_clip
|
||||
#cgo noescape mlx_sum
|
||||
#cgo nocallback mlx_sum
|
||||
#cgo noescape mlx_sum_axis
|
||||
#cgo nocallback mlx_sum_axis
|
||||
#cgo noescape mlx_mean
|
||||
#cgo nocallback mlx_mean
|
||||
#cgo noescape mlx_mean_axis
|
||||
#cgo nocallback mlx_mean_axis
|
||||
#cgo noescape mlx_var_axis
|
||||
#cgo nocallback mlx_var_axis
|
||||
#cgo noescape mlx_argmax
|
||||
#cgo nocallback mlx_argmax
|
||||
#cgo noescape mlx_argmax_axis
|
||||
#cgo nocallback mlx_argmax_axis
|
||||
#cgo noescape mlx_softmax_axis
|
||||
#cgo nocallback mlx_softmax_axis
|
||||
#cgo noescape mlx_cumsum
|
||||
#cgo nocallback mlx_cumsum
|
||||
#cgo noescape mlx_matmul
|
||||
#cgo nocallback mlx_matmul
|
||||
#cgo noescape mlx_addmm
|
||||
#cgo nocallback mlx_addmm
|
||||
#cgo noescape mlx_gather_mm
|
||||
#cgo nocallback mlx_gather_mm
|
||||
#cgo noescape mlx_gather_qmm
|
||||
#cgo nocallback mlx_gather_qmm
|
||||
#cgo noescape mlx_reshape
|
||||
#cgo nocallback mlx_reshape
|
||||
#cgo noescape mlx_transpose_axes
|
||||
#cgo nocallback mlx_transpose_axes
|
||||
#cgo noescape mlx_expand_dims
|
||||
#cgo nocallback mlx_expand_dims
|
||||
#cgo noescape mlx_squeeze_axis
|
||||
#cgo nocallback mlx_squeeze_axis
|
||||
#cgo noescape mlx_flatten
|
||||
#cgo nocallback mlx_flatten
|
||||
#cgo noescape mlx_concatenate_axis
|
||||
#cgo nocallback mlx_concatenate_axis
|
||||
#cgo noescape mlx_slice
|
||||
#cgo nocallback mlx_slice
|
||||
#cgo noescape mlx_slice_update
|
||||
#cgo nocallback mlx_slice_update
|
||||
#cgo noescape mlx_as_strided
|
||||
#cgo nocallback mlx_as_strided
|
||||
#cgo noescape mlx_view
|
||||
#cgo nocallback mlx_view
|
||||
#cgo noescape mlx_contiguous
|
||||
#cgo nocallback mlx_contiguous
|
||||
#cgo noescape mlx_pad
|
||||
#cgo nocallback mlx_pad
|
||||
#cgo noescape mlx_tile
|
||||
#cgo nocallback mlx_tile
|
||||
#cgo noescape mlx_take_axis
|
||||
#cgo nocallback mlx_take_axis
|
||||
#cgo noescape mlx_take_along_axis
|
||||
#cgo nocallback mlx_take_along_axis
|
||||
#cgo noescape mlx_put_along_axis
|
||||
#cgo nocallback mlx_put_along_axis
|
||||
#cgo noescape mlx_where
|
||||
#cgo nocallback mlx_where
|
||||
#cgo noescape mlx_argsort_axis
|
||||
#cgo nocallback mlx_argsort_axis
|
||||
#cgo noescape mlx_argpartition_axis
|
||||
#cgo nocallback mlx_argpartition_axis
|
||||
#cgo noescape mlx_topk_axis
|
||||
#cgo nocallback mlx_topk_axis
|
||||
#cgo noescape mlx_less
|
||||
#cgo nocallback mlx_less
|
||||
#cgo noescape mlx_greater_equal
|
||||
#cgo nocallback mlx_greater_equal
|
||||
#cgo noescape mlx_logical_and
|
||||
#cgo nocallback mlx_logical_and
|
||||
#cgo noescape mlx_zeros
|
||||
#cgo nocallback mlx_zeros
|
||||
#cgo noescape mlx_zeros_like
|
||||
#cgo nocallback mlx_zeros_like
|
||||
#cgo noescape mlx_ones
|
||||
#cgo nocallback mlx_ones
|
||||
#cgo noescape mlx_full
|
||||
#cgo nocallback mlx_full
|
||||
#cgo noescape mlx_arange
|
||||
#cgo nocallback mlx_arange
|
||||
#cgo noescape mlx_linspace
|
||||
#cgo nocallback mlx_linspace
|
||||
#cgo noescape mlx_tri
|
||||
#cgo nocallback mlx_tri
|
||||
#cgo noescape mlx_astype
|
||||
#cgo nocallback mlx_astype
|
||||
#cgo noescape mlx_fast_rms_norm
|
||||
#cgo nocallback mlx_fast_rms_norm
|
||||
#cgo noescape mlx_fast_rope
|
||||
#cgo nocallback mlx_fast_rope
|
||||
#cgo noescape mlx_fast_scaled_dot_product_attention
|
||||
#cgo nocallback mlx_fast_scaled_dot_product_attention
|
||||
#cgo noescape mlx_conv2d
|
||||
#cgo nocallback mlx_conv2d
|
||||
#cgo noescape mlx_conv3d
|
||||
#cgo nocallback mlx_conv3d
|
||||
#cgo noescape mlx_random_key
|
||||
#cgo nocallback mlx_random_key
|
||||
#cgo noescape mlx_random_split
|
||||
#cgo nocallback mlx_random_split
|
||||
#cgo noescape mlx_random_categorical_num_samples
|
||||
#cgo nocallback mlx_random_categorical_num_samples
|
||||
#cgo noescape mlx_random_normal
|
||||
#cgo nocallback mlx_random_normal
|
||||
#cgo noescape mlx_random_uniform
|
||||
#cgo nocallback mlx_random_uniform
|
||||
#cgo noescape mlx_array_eval
|
||||
#cgo nocallback mlx_array_eval
|
||||
#cgo noescape mlx_eval
|
||||
#cgo nocallback mlx_eval
|
||||
#cgo noescape mlx_async_eval
|
||||
#cgo nocallback mlx_async_eval
|
||||
#cgo noescape mlx_synchronize
|
||||
#cgo nocallback mlx_synchronize
|
||||
#cgo noescape mlx_array_new
|
||||
#cgo nocallback mlx_array_new
|
||||
#cgo noescape mlx_array_new_data
|
||||
#cgo nocallback mlx_array_new_data
|
||||
#cgo noescape mlx_array_new_float
|
||||
#cgo nocallback mlx_array_new_float
|
||||
#cgo noescape mlx_array_free
|
||||
#cgo nocallback mlx_array_free
|
||||
#cgo noescape mlx_array_size
|
||||
#cgo nocallback mlx_array_size
|
||||
#cgo noescape mlx_array_ndim
|
||||
#cgo nocallback mlx_array_ndim
|
||||
#cgo noescape mlx_array_dim
|
||||
#cgo nocallback mlx_array_dim
|
||||
#cgo noescape mlx_array_dtype
|
||||
#cgo nocallback mlx_array_dtype
|
||||
#cgo noescape mlx_array_item_int32
|
||||
#cgo nocallback mlx_array_item_int32
|
||||
#cgo noescape mlx_vector_array_new_data
|
||||
#cgo nocallback mlx_vector_array_new_data
|
||||
#cgo noescape mlx_vector_array_free
|
||||
#cgo nocallback mlx_vector_array_free
|
||||
#cgo noescape mlx_array_new_int
|
||||
#cgo nocallback mlx_array_new_int
|
||||
#cgo noescape mlx_stream_new_device
|
||||
#cgo nocallback mlx_stream_new_device
|
||||
#cgo noescape mlx_get_default_stream
|
||||
#cgo nocallback mlx_get_default_stream
|
||||
#cgo noescape mlx_set_default_stream
|
||||
#cgo nocallback mlx_set_default_stream
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
@@ -1796,7 +1611,57 @@ func ArgmaxKeepArray(logits *Array) *Array {
|
||||
var RandomState = []*Array{nil}
|
||||
var randomStateMu sync.Mutex
|
||||
|
||||
var mlxInitialized bool
|
||||
var mlxInitError error
|
||||
|
||||
// InitMLX initializes the MLX library by dynamically loading libmlxc.
|
||||
// This must be called before using any MLX functions.
|
||||
// Returns an error if the library cannot be loaded.
|
||||
func InitMLX() error {
|
||||
if mlxInitialized {
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
// Try to load the MLX dynamic library
|
||||
ret := C.mlx_dynamic_init()
|
||||
if ret != 0 {
|
||||
errMsg := C.GoString(C.mlx_dynamic_error())
|
||||
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
// Initialize all function pointers via dlsym
|
||||
handle := C.mlx_get_handle()
|
||||
ret = C.mlx_load_functions(handle)
|
||||
if ret != 0 {
|
||||
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
mlxInitialized = true
|
||||
mlxInitError = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsMLXAvailable returns whether MLX was successfully initialized
|
||||
func IsMLXAvailable() bool {
|
||||
return mlxInitialized && mlxInitError == nil
|
||||
}
|
||||
|
||||
// GetMLXInitError returns the error recorded by the most recent failed MLX
// initialization attempt, or nil if no failure has been recorded.
|
||||
func GetMLXInitError() error {
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Initialize MLX dynamic library first
|
||||
if err := InitMLX(); err != nil {
|
||||
// Don't panic in init - let the caller handle the error
|
||||
// Store the error for later retrieval
|
||||
mlxInitError = err
|
||||
return
|
||||
}
|
||||
|
||||
// Lock main goroutine to OS thread for CUDA context stability.
|
||||
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
|
||||
runtime.LockOSThread()
|
||||
|
||||
2337
x/imagegen/mlx/mlx.h
Normal file
2337
x/imagegen/mlx/mlx.h
Normal file
File diff suppressed because it is too large
Load Diff
144
x/imagegen/mlx/mlx_dynamic.c
Normal file
144
x/imagegen/mlx/mlx_dynamic.c
Normal file
@@ -0,0 +1,144 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
#include <mach-o/dyld.h>
|
||||
#include <libgen.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
#ifdef __APPLE__
|
||||
// Build the path of `libname` inside the directory that holds the running
// executable.
//
// Returns a pointer to a static buffer (overwritten by the next call), or
// NULL when the executable path cannot be obtained, its directory cannot be
// derived, or the combined path does not fit in the buffer.
static char* get_exe_relative_path(const char* libname) {
    static char path[1024];
    static char fullpath[1024];

    uint32_t size = sizeof(path);
    if (_NSGetExecutablePath(path, &size) != 0) {
        return NULL;
    }

    // Get directory of executable
    char* dir = dirname(path);
    if (dir == NULL) {
        return NULL;
    }

    // snprintf returns the length it wanted to write; a value at or past the
    // buffer size means the path was cut short and would name the wrong file.
    int written = snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname);
    if (written < 0 || (size_t)written >= sizeof(fullpath)) {
        return NULL;
    }
    return fullpath;
}
|
||||
#endif
|
||||
|
||||
// Try to load library from a specific path
|
||||
static int try_load_lib(const char* path) {
|
||||
if (!path) return 0;
|
||||
mlx_handle = LOAD_LIB(path);
|
||||
return mlx_handle != NULL;
|
||||
}
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
const char* lib_path = NULL;
|
||||
const char* tried_paths[8] = {0};
|
||||
int num_tried = 0;
|
||||
|
||||
#ifdef _WIN32
|
||||
// Windows: try same directory as executable
|
||||
lib_path = "libmlxc.dll";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#elif defined(__APPLE__)
|
||||
// macOS: try executable directory first
|
||||
lib_path = get_exe_relative_path("libmlxc.dylib");
|
||||
if (lib_path) {
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
}
|
||||
// Try build directory (for tests run from repo root)
|
||||
lib_path = "./build/lib/ollama/libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#else
|
||||
// Linux: try build directory first (for tests)
|
||||
lib_path = "./build/lib/ollama/libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#endif
|
||||
|
||||
// Failed to load library - build error message with all tried paths
|
||||
{
|
||||
const char* err = LIB_ERROR();
|
||||
int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. Tried: ");
|
||||
for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) {
|
||||
offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
"%s%s", i > 0 ? ", " : "", tried_paths[i]);
|
||||
}
|
||||
if (err) {
|
||||
snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
". Last error: %s", err);
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
|
||||
success:
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", lib_path ? lib_path : "library");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Get the last message written by mlx_dynamic_init().
// The returned pointer refers to a static buffer in this file; besides failure
// details it holds a "Successfully loaded" note after a successful init.
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized.
// Returns the mlx_initialized flag: 1 after mlx_dynamic_init() has succeeded,
// 0 before that or after mlx_dynamic_cleanup() released the library.
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Get the library handle (for use by generated wrappers).
// Returns the handle stored by the most recent load attempt; it is NULL when
// no library is currently loaded.
|
||||
void* mlx_get_handle(void) {
|
||||
return mlx_handle;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
29
x/imagegen/mlx/mlx_dynamic.h
Normal file
29
x/imagegen/mlx/mlx_dynamic.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Get the library handle (for use by generated wrappers)
|
||||
void* mlx_get_handle(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
@@ -4,9 +4,30 @@ package mlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping MLX tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive.
|
||||
func TestBasicCleanup(t *testing.T) {
|
||||
weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
|
||||
|
||||
@@ -3,12 +3,33 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
|
||||
@@ -3,13 +3,35 @@
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
@@ -3,12 +3,34 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping nn tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
|
||||
func TestLinearNoBias(t *testing.T) {
|
||||
// Weight: [out=2, in=3] -> transposed at forward time
|
||||
|
||||
@@ -62,6 +62,12 @@ func Execute(args []string) error {
|
||||
return fmt.Errorf("--port is required")
|
||||
}
|
||||
|
||||
err := mlx.InitMLX()
|
||||
if err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
slog.Info("starting image runner", "model", *modelName, "port", *port)
|
||||
|
||||
// Check memory requirements before loading
|
||||
|
||||
@@ -62,7 +62,7 @@ func NewServer(modelName string) (*Server, error) {
|
||||
port = rand.Intn(65535-49152) + 49152
|
||||
}
|
||||
|
||||
// Get the ollama-mlx executable path (in same directory as current executable)
|
||||
// Get the current executable path (we use the same binary with runner subcommand)
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
@@ -70,10 +70,9 @@ func NewServer(modelName string) (*Server, error) {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
|
||||
|
||||
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
@@ -135,7 +134,7 @@ func NewServer(modelName string) (*Server, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
|
||||
slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start image runner: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
include(FetchContent)
|
||||
|
||||
# Read MLX version from top-level file (shared with Dockerfile)
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
@@ -50,7 +54,7 @@ endif()
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG v0.4.1)
|
||||
GIT_TAG ${MLX_C_GIT_TAG})
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
set_target_output_directory(mlx)
|
||||
|
||||
92
x/ml/backend/mlx/mlx_dynamic.c
Normal file
92
x/ml/backend/mlx/mlx_dynamic.c
Normal file
@@ -0,0 +1,92 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.dylib",
|
||||
"@loader_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"@executable_path/../build/lib/ollama/libmlxc.dylib",
|
||||
"build/lib/ollama/libmlxc.dylib",
|
||||
"../build/lib/ollama/libmlxc.dylib",
|
||||
NULL
|
||||
};
|
||||
#else
|
||||
static const char* LIB_NAMES[] = {
|
||||
"libmlxc.so",
|
||||
"$ORIGIN/../build/lib/ollama/libmlxc.so",
|
||||
"build/lib/ollama/libmlxc.so",
|
||||
"../build/lib/ollama/libmlxc.so",
|
||||
NULL
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
// Try each possible library path
|
||||
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
|
||||
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
|
||||
if (mlx_handle != NULL) {
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", LIB_NAMES[i]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to load library
|
||||
const char* err = LIB_ERROR();
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. %s",
|
||||
err ? err : "Unknown error");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Get the last message written by mlx_dynamic_init().
// The returned pointer refers to a static buffer in this file; besides failure
// details it holds a "Successfully loaded" note after a successful init.
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized.
// Returns the mlx_initialized flag: 1 after mlx_dynamic_init() has succeeded,
// 0 before that or after mlx_dynamic_cleanup() released the library.
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
26
x/ml/backend/mlx/mlx_dynamic.h
Normal file
26
x/ml/backend/mlx/mlx_dynamic.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
Reference in New Issue
Block a user