Files
exo/python/parts.nix
Andrei Cravtov 629c55d6ba Rename exo_pyo3_bindings to exo_rs (#2131)
## Motivation

(I think it) Makes Evan's massive PR easier to merge later on

## Changes

- Renamed exo_pyo3_bindings to exo_rs
- Upgraded versions of pyo3-based dependencies
- Renamed PyFromSwarm to just FromSwarm, and PyNetworkingHandle to just
NetworkingHandle
2026-05-31 19:23:41 +01:00

303 lines
13 KiB
Nix

{ inputs, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = ../.;
};
mkPythonSet = { pkgs, lib, self', members }:
let
inherit (pkgs.stdenv.hostPlatform) isLinux isDarwin isx86_64;
inherit (pkgs.config) cudaSupport;
inherit (pkgs) cudaPackages;
libmlx_source =
if (builtins.elem "mlx-cuda13" members.exo or [ ]) then "mlx-cuda-13"
else if (builtins.elem "mlx-cuda12" members.exo or [ ]) then "mlx-cuda-12"
else "mlx-cpu";
python = pkgs.python313;
cuda_cccl_compat = pkgs.runCommand "cuda-cccl-compat" { } ''
mkdir -p $out/include
ln -s ${cudaPackages.cuda_cccl}/include $out/include/cccl
'';
cudaLibs = with cudaPackages; [
cuda_crt
cuda_cudart
cuda_cccl
cuda_cupti
cuda_nvrtc
cuda_nvtx
cudnn
libcufile
libcublas
libcufft
libcurand
libcusolver
libcusparse
libcusparse_lt
libnvjitlink
libnvshmem
nccl
];
cudaRoot = pkgs.symlinkJoin {
name = "cuda-merged-exo";
paths = builtins.concatMap (p: [ (lib.getBin p) (lib.getLib p) (lib.getDev p) ]) (cudaLibs ++ [ cudaPackages.cuda_nvcc cuda_cccl_compat ]);
};
exoOverlay = final: prev: {
# Replace workspace exo_rs with Nix-built wheel.
# Preserve passthru so mkVirtualEnv can resolve dependency groups.
# Copy .pyi stub + py.typed marker so basedpyright can find the types.
exo-rs = pkgs.stdenv.mkDerivation {
pname = "exo-rs";
version = "0.1.0";
src = self'.packages.exo-rs;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
passthru = prev.exo-rs.passthru or { };
postInstall = ''
local siteDir=$out/${final.python.sitePackages}/exo_rs
cp ${inputs.self}/rust/exo_rs/exo_rs.pyi $siteDir/
touch $siteDir/py.typed
'';
};
};
buildSystemsOverlay = final: prev:
lib.optionalAttrs isDarwin
{
mlx = prev.mlx.overrideAttrs (old:
let
# Static dependencies included directly during compilation
gguf-tools = pkgs.fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = pkgs.fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
nanobind = pkgs.fetchFromGitHub {
owner = "wjakob";
repo = "nanobind";
rev = "v2.10.2";
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
fetchSubmodules = true;
};
in
{
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [ pkgs.cmake self'.packages.metal-toolchain ];
# TODO: non-sdk_26 support
buildInputs = (old.buildInputs or [ ])
++ [ gguf-tools pkgs.fmt pkgs.nlohmann_json pkgs.apple-sdk_26 ];
patches = [
(pkgs.replaceVars ../nix/darwin-build-fixes.patch {
sdkVersion = pkgs.apple-sdk_26.version;
inherit (self'.packages.metal-toolchain) metalVersion;
})
];
postPatch = ''
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "${lib.getExe' pkgs.stdenv.cc "c++"}"
'';
DEV_RELEASE = 1;
CMAKE_ARGS = toString ([
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${pkgs.nlohmann_json.src}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_CPU" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "string" "CMAKE_INSTALL_LIBDIR" "lib")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${pkgs.apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${pkgs.apple-sdk_26.passthru.sdkroot}")
] ++ lib.optionals (isDarwin && isx86_64) [
(lib.cmakeBool "MLX_ENABLE_X64_MAC" true)
]);
SDKROOT = pkgs.apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = pkgs.apple-sdk_26.version;
});
} // lib.optionalAttrs isLinux {
mlx = prev.mlx.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ lib.optionals cudaSupport [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ lib.optionals cudaSupport cudaLibs;
postInstall = ''
cp -r "${final.${libmlx_source}}/${final.python.sitePackages}/mlx" "$out/${final.python.sitePackages}/mlx/"
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
} // lib.optionalAttrs cudaSupport {
"${libmlx_source}" = prev."${libmlx_source}".overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ cudaLibs;
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
nvidia-cufile = prev.nvidia-cufile.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ [ pkgs.rdma-core ];
});
nvidia-cusolver = prev.nvidia-cusolver.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ cudaLibs;
});
nvidia-nvshmem-cu13 = prev.nvidia-nvshmem-cu13.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ [ pkgs.rdma-core pkgs.pmix pkgs.libfabric pkgs.ucx pkgs.openmpi ];
});
nvidia-cusparse = prev.nvidia-cusparse.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ cudaLibs;
});
torch = prev.torch.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ cudaLibs;
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torchaudio = prev.torchaudio.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
buildInputs = old.buildInputs ++ [ cudaPackages.cuda_cudart ];
preFixup = "addAutoPatchelfSearchPath '${final.torch}'";
});
torchvision = prev.torchvision.overrideAttrs (old: {
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.autoAddDriverRunpath ];
preFixup = "addAutoPatchelfSearchPath '${final.torch}'";
});
torch-c-dlpack-ext = prev.torch-c-dlpack-ext.overrideAttrs (old: {
buildInputs = old.buildInputs ++ cudaLibs;
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
preFixup = "addAutoPatchelfSearchPath '${final.torch}'";
});
} // lib.optionalAttrs (cudaSupport && isx86_64) {
numba = prev.numba.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [ pkgs.tbb ];
});
};
pyprojectOverlay = workspace.mkPyprojectOverlay {
sourcePreference = "wheel";
dependencies = members;
};
editableOverlay = workspace.mkEditablePyprojectOverlay {
# Use environment variable pointing to editable root directory
root = "$REPO_ROOT";
members = [ "exo" "exo-bench" ];
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
pyprojectOverlay
exoOverlay
buildSystemsOverlay
]
);
# mlx and mlx-cuda ship clashing cmake files - we dont need them at runtime anyway
venv = name: (pythonSet.mkVirtualEnv "${name}-venv" members).overrideAttrs (_: { venvSkip = [ "lib/python${python.pythonVersion}/site-packages/mlx/share/cmake/*" "lib/python${python.pythonVersion}/site-packages/build_backend.py" ]; });
mkApp =
let
libPath = lib.makeLibraryPath (
[ pkgs.stdenv.cc.cc.lib ] ++ lib.optionals cudaSupport [ cudaRoot ]
);
in
text: name: pkgs.writeShellApplication {
inherit name;
text = ''
LD_LIBRARY_PATH="${libPath}''${LD_LIBRARY_PATH:+:}''${LD_LIBRARY_PATH:-}" exec \
${lib.optionalString cudaSupport "nixglhost "} ${text}
'';
runtimeEnv = {
EXO_DASHBOARD_DIR = self'.packages.dashboard;
EXO_RESOURCES_DIR = inputs.self + /resources;
};
runtimeInputs = [
(venv name)
] ++ lib.optionals cudaSupport [ pkgs.nix-gl-host ]
++ lib.optionals isDarwin [ pkgs.macmon ];
passthru = {
venv = venv name;
evenv = ((pythonSet.overrideScope editableOverlay).mkVirtualEnv "${name}-evenv" (members // { exo = (members.exo or [ ]) ++ [ "dev" ]; exo-rs = [ ]; })).overrideAttrs (_: {
venvSkip = [ "lib/python${python.pythonVersion}/site-packages/mlx/share/cmake/*" "lib/python${python.pythonVersion}/site-packages/build_backend.py" ];
});
} // lib.optionalAttrs cudaSupport {
inherit cudaRoot;
};
};
in
{
inherit venv;
mkPythonScript = path: mkApp ''python ${path} "$@"'';
exo = mkApp ''exo "$@"'' "exo";
};
in
{
perSystem =
{ self', pkgs, unfreePkgs, lib, ... }:
let
inherit (pkgs.stdenv.hostPlatform) isLinux;
inherit (mkPythonSet { inherit self' pkgs lib; members = { exo = [ "mlx-cpu" ]; }; }) exo;
# Virtual environment with dev dependencies for testing
testVenv = (mkPythonSet {
inherit self' pkgs lib; members = {
exo = [ "dev" "mlx-cpu" ]; # Include pytest, pytest-asyncio, pytest-env
};
}).venv "exo-test";
mkBenchScript = (mkPythonSet {
inherit self' pkgs lib; members = {
exo = [ "mlx-cpu" ];
exo-bench = [ ]; # Include pytest, pytest-asyncio, pytest-env
};
}).mkPythonScript;
mkSimplePythonScript = name: path: pkgs.writeShellApplication {
inherit name;
runtimeInputs = [ pkgs.python313 ];
text = ''exec python ${path} "$@"'';
};
# if someone is particularly interested in cuda12 support in nix, please open an issue.
# until then, it's more hassle than its worth
#cuda12Set = mkPythonSet { inherit self' lib; inherit (unfreePkgs.pkgsCuda.cudaPackages_12) pkgs; members = { exo = [ "mlx-cuda12" ]; }; };
cuda13Set = mkPythonSet { inherit self' lib; inherit (unfreePkgs.pkgsCuda.cudaPackages_13) pkgs; members = { exo = [ "mlx-cuda13" ]; }; };
in
{
packages = {
inherit exo;
# for running tests in ci
exo-test-env = testVenv;
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
exo-eval = mkBenchScript "exo-eval" (inputs.self + /bench/exo_eval.py);
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
# used by ./tests/run_exo_on.sh
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
} // lib.optionalAttrs isLinux {
#exo-cuda-12 = cuda12Set.exo;
exo-cuda-13 = cuda13Set.exo;
};
checks = {
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}
touch $out
'';
typecheck = pkgs.runCommand "typecheck" { nativeBuildInputs = [ testVenv ]; } ''
cd ${inputs.self}
basedpyright
touch $out
'';
};
};
}