mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-19 12:15:07 -04:00
mlx cuda 13 (dgx spark) support (#1874)
This commit is contained in:
304
python/parts.nix
304
python/parts.nix
@@ -1,18 +1,37 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', pkgs, lib, system, ... }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = ../.;
|
||||
};
|
||||
|
||||
mkPythonSet = { pkgs, lib, self' }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = inputs.self;
|
||||
};
|
||||
|
||||
# Create overlay from workspace
|
||||
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
|
||||
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
inherit (pkgs.stdenv.hostPlatform) isLinux isDarwin isx86_64;
|
||||
inherit (pkgs.config) cudaSupport;
|
||||
inherit (pkgs) cudaPackages;
|
||||
cuda13Support = cudaSupport && cudaPackages.cudaMajorVersion == "13";
|
||||
libmlx_source = if cuda13Support then "mlx-cuda-13" else if cudaSupport then "mlx-cuda-12" else "mlx-cpu";
|
||||
uv_extra = if cuda13Support then "cuda13" else if cudaSupport then "cuda12" else "cpu";
|
||||
python = pkgs.python313;
|
||||
cudaLibs = with cudaPackages; [
|
||||
cuda_cudart
|
||||
cuda_cccl
|
||||
cuda_cupti
|
||||
cuda_nvrtc
|
||||
cuda_nvtx
|
||||
cudnn
|
||||
libcufile
|
||||
libcublas
|
||||
libcufft
|
||||
libcurand
|
||||
libcusolver
|
||||
libcusparse
|
||||
libcusparse_lt
|
||||
libnvjitlink
|
||||
libnvshmem
|
||||
nccl
|
||||
];
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel.
|
||||
# Preserve passthru so mkVirtualEnv can resolve dependency groups.
|
||||
@@ -32,126 +51,157 @@
|
||||
'';
|
||||
};
|
||||
};
|
||||
buildSystemsOverlay = final: prev: { } //
|
||||
lib.optionalAttrs isDarwin
|
||||
{
|
||||
mlx = prev.mlx.overrideAttrs (old:
|
||||
let
|
||||
# Static dependencies included directly during compilation
|
||||
gguf-tools = pkgs.fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
python = pkgs.python313;
|
||||
metal_cpp = pkgs.fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
nanobind = pkgs.fetchFromGitHub {
|
||||
owner = "wjakob";
|
||||
repo = "nanobind";
|
||||
rev = "v2.10.2";
|
||||
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
|
||||
fetchSubmodules = true;
|
||||
};
|
||||
in
|
||||
{
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [ pkgs.cmake self'.packages.metal-toolchain ];
|
||||
# TODO: non-sdk_26 support
|
||||
buildInputs = (old.buildInputs or [ ])
|
||||
++ [ gguf-tools pkgs.fmt pkgs.nlohmann_json pkgs.apple-sdk_26 ];
|
||||
patches = [
|
||||
(pkgs.replaceVars ../nix/darwin-build-fixes.patch {
|
||||
sdkVersion = pkgs.apple-sdk_26.version;
|
||||
inherit (self'.packages.metal-toolchain) metalVersion;
|
||||
})
|
||||
];
|
||||
postPatch = ''
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "${lib.getExe' pkgs.stdenv.cc "c++"}"
|
||||
'';
|
||||
|
||||
DEV_RELEASE = 1;
|
||||
CMAKE_ARGS = toString ([
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${pkgs.nlohmann_json.src}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_CPU" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "string" "CMAKE_INSTALL_LIBDIR" "lib")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${pkgs.apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${pkgs.apple-sdk_26.passthru.sdkroot}")
|
||||
] ++ lib.optionals (isDarwin && isx86_64) [
|
||||
(lib.cmakeBool "MLX_ENABLE_X64_MAC" true)
|
||||
]);
|
||||
SDKROOT = pkgs.apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = pkgs.apple-sdk_26.version;
|
||||
});
|
||||
} // lib.optionalAttrs isLinux {
|
||||
mlx = prev.mlx.overrideAttrs (old: {
|
||||
buildInputs = old.buildInputs ++ lib.optionals cudaSupport cudaLibs;
|
||||
autoPatchelfIgnoreMissingDeps = lib.optionals cudaSupport [ "libcuda.so.1" ];
|
||||
postInstall = (old.postInstall or "") + ''
|
||||
cp -r "${final.${libmlx_source}}/${final.python.sitePackages}/mlx" "$out/${final.python.sitePackages}/mlx/"
|
||||
'';
|
||||
});
|
||||
# rouge-score and sacrebleu don't declare setuptools as a build dependency
|
||||
rouge-score = prev.rouge-score.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
} // lib.optionalAttrs cudaSupport {
|
||||
"${libmlx_source}" = prev."${libmlx_source}".overrideAttrs (old: {
|
||||
buildInputs = old.buildInputs ++ cudaLibs;
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
sacrebleu = prev.sacrebleu.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
nvidia-cufile = prev.nvidia-cufile.overrideAttrs (old: {
|
||||
buildInputs = old.buildInputs ++ [ pkgs.rdma-core ];
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
sqlitedict = prev.sqlitedict.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
nvidia-cusolver = prev.nvidia-cusolver.overrideAttrs (old: {
|
||||
buildInputs = old.buildInputs ++ cudaLibs;
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
word2number = prev.word2number.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
nvidia-nvshmem-cu13 = prev.nvidia-nvshmem-cu13.overrideAttrs (old: {
|
||||
buildInputs = old.buildInputs ++ [ pkgs.rdma-core pkgs.pmix pkgs.libfabric pkgs.ucx pkgs.openmpi ];
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
nvidia-cusparse = prev.nvidia-cusparse.overrideAttrs (old: {
|
||||
buildInputs = old.buildInputs ++ [ cudaLibs ];
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
torch = prev.torch.overrideAttrs (old: {
|
||||
buildInputs = old.buildInputs ++ cudaLibs;
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
} // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
# Use our pure Nix-built MLX with Metal support (macOS only)
|
||||
mlx = self'.packages.mlx;
|
||||
};
|
||||
|
||||
# Additional overlay for Linux-specific fixes (type checking env).
|
||||
# Native wheels have shared lib dependencies we don't need at type-check time.
|
||||
linuxOverlay = final: prev:
|
||||
let
|
||||
ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ "*" ]; };
|
||||
nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix "nvidia-" name) prev;
|
||||
in
|
||||
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
|
||||
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
|
||||
mlx = ignoreMissing prev.mlx;
|
||||
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
|
||||
buildInputs = (old.buildInputs or [ ]) ++ [
|
||||
final.nvidia-cublas
|
||||
final.nvidia-cuda-nvrtc
|
||||
final.nvidia-cudnn-cu13
|
||||
final.nvidia-nccl-cu13
|
||||
];
|
||||
preFixup = ''
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cublas}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
|
||||
'';
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
torch = ignoreMissing prev.torch;
|
||||
triton = ignoreMissing prev.triton;
|
||||
}
|
||||
);
|
||||
|
||||
pyprojectOverlay = workspace.mkPyprojectOverlay {
|
||||
sourcePreference = "wheel";
|
||||
dependencies = { exo = [ uv_extra ]; exo-bench = [ ]; };
|
||||
};
|
||||
editableOverlay = workspace.mkEditablePyprojectOverlay {
|
||||
# Use environment variable pointing to editable root directory
|
||||
root = "$REPO_ROOT";
|
||||
members = [ "exo" "exo-bench" ];
|
||||
};
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
inputs.pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
pyprojectOverlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
linuxOverlay
|
||||
]
|
||||
);
|
||||
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
|
||||
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
|
||||
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
|
||||
"lib/python3.13/site-packages/mlx*"
|
||||
"lib/python3.13/site-packages/nvidia*"
|
||||
];
|
||||
|
||||
# Exclude bench deps from main env (bench has its own benchVenv)
|
||||
exoDeps = removeAttrs workspace.deps.default [ "exo-bench" ];
|
||||
|
||||
exoVenv = (pythonSet.mkVirtualEnv "exo-env" exoDeps).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
exoDeps // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
)).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
mkApp = cmd: name: members: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ exoVenv ];
|
||||
runtimeEnv = {
|
||||
EXO_DASHBOARD_DIR = self'.packages.dashboard;
|
||||
EXO_RESOURCES_DIR = inputs.self + /resources;
|
||||
};
|
||||
text = ''exec python ${path} "$@"'';
|
||||
runtimeInputs = [
|
||||
# mlx and mlx-cuda ship clashing cmake files - we dont need them at runtime anyway
|
||||
((pythonSet.mkVirtualEnv "${name}-env" members).overrideAttrs (_: { venvSkip = [ "lib/python${python.pythonVersion}/site-packages/mlx/share/cmake/*" ]; }))
|
||||
]
|
||||
++ lib.optionals isDarwin [ pkgs.macmon ];
|
||||
text = "exec " + lib.optionalString cudaSupport "${lib.getExe pkgs.nix-gl-host} " + cmd;
|
||||
};
|
||||
in
|
||||
{
|
||||
inherit pythonSet;
|
||||
editablePythonSet = pythonSet.overrideScope editableOverlay;
|
||||
mkPythonScript = members: name: path: mkApp ''python ${path} "$@"'' name members;
|
||||
mkExo = name: members: mkApp ''exo "$@"'' name members;
|
||||
};
|
||||
in
|
||||
{
|
||||
perSystem =
|
||||
{ self', pkgs, unfreePkgs, lib, ... }:
|
||||
let
|
||||
inherit (pkgs.stdenv.hostPlatform) isLinux;
|
||||
inherit (mkPythonSet { inherit self' pkgs lib; }) pythonSet editablePythonSet mkPythonScript mkExo;
|
||||
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" { exo = lib.optionals isLinux [ "cpu" ]; };
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" {
|
||||
exo = [ "dev" ] ++ lib.optionals isLinux [ "cpu" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
};
|
||||
|
||||
benchVenv = pythonSet.mkVirtualEnv "exo-bench-env" {
|
||||
exo-bench = [ ];
|
||||
};
|
||||
|
||||
mkBenchScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
runtimeInputs = [ benchVenv ];
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
mkBenchScript = mkPythonScript { exo-bench = [ ]; };
|
||||
|
||||
mkSimplePythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
@@ -159,46 +209,32 @@
|
||||
text = ''exec python ${path} "$@"'';
|
||||
};
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper script
|
||||
makeWrapper ${exoVenv}/bin/exo $out/bin/exo \
|
||||
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
--set EXO_RESOURCES_DIR ${inputs.self + /resources} \
|
||||
${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin
|
||||
{
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
} // {
|
||||
|
||||
inherit python;
|
||||
|
||||
packages = {
|
||||
exo = mkExo "exo" { exo = lib.optionals isLinux [ "cpu" ]; };
|
||||
# for devShell
|
||||
exo-venv = exoVenv;
|
||||
editableVenv = editablePythonSet.mkVirtualEnv "exo-dev-env" { exo = [ "dev" ]; };
|
||||
# for running tests in ci
|
||||
exo-test-env = testVenv;
|
||||
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
exo-eval = mkBenchScript "exo-eval" (inputs.self + /bench/exo_eval.py);
|
||||
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
|
||||
# used by ./tests/run_exo_on.sh
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
} // lib.optionalAttrs isLinux {
|
||||
exo-cuda-12 = (mkPythonSet { inherit self' lib; inherit (unfreePkgs.pkgsCuda.cudaPackages_12) pkgs; }).mkExo "exo-cuda-12" { exo = [ "cuda12" ]; };
|
||||
exo-cuda-13 = (mkPythonSet { inherit self' lib; inherit (unfreePkgs.pkgsCuda.cudaPackages_13) pkgs; }).mkExo "exo-cuda-13" { exo = [ "cuda13" ]; };
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}
|
||||
touch $out
|
||||
'';
|
||||
|
||||
# Hermetic basedpyright type checking
|
||||
typecheck = pkgs.runCommand "typecheck"
|
||||
{
|
||||
nativeBuildInputs = [
|
||||
@@ -209,7 +245,7 @@
|
||||
''
|
||||
cd ${inputs.self}
|
||||
export HOME=$TMPDIR
|
||||
basedpyright --pythonpath ${testVenv}/bin/python
|
||||
basedpyright --pythonpath ${testVenv}/bin/python --project ${inputs.self}/pyproject.toml
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user