From a026277ab98d9f45aa358187a0a8f90141e9ea9c Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Mon, 9 Mar 2026 17:29:32 +0100
Subject: [PATCH] feat(mlx-distributed): add new MLX-distributed backend (#8801)

* feat(mlx-distributed): add new MLX-distributed backend

Add a new MLX distributed backend with support for both TCP and RDMA
for model sharding. This implementation ties into the discovery
mechanism already in place and re-uses the same P2P mechanism for TCP
MLX-distributed inferencing.

The auto-parallel implementation is inspired by Exo's (which has been
added to the acknowledgements for the great work!)

Signed-off-by: Ettore Di Giacinto

* expose a CLI to facilitate starting the backend

Signed-off-by: Ettore Di Giacinto

* feat: make manual rank0 configurable via model configs

Signed-off-by: Ettore Di Giacinto

* Add missing features from the mlx backend

Signed-off-by: Ettore Di Giacinto

* Apply suggestion from @mudler

Signed-off-by: Ettore Di Giacinto

---------

Signed-off-by: Ettore Di Giacinto
Signed-off-by: Ettore Di Giacinto
---
 .github/workflows/backend.yml                 |  68 +++
 Makefile                                      |  10 +-
 README.md                                     |   1 +
 backend/index.yaml                            |  81 +++
 backend/python/mlx-distributed/Makefile       |  23 +
 backend/python/mlx-distributed/backend.py     | 509 ++++++++++++++++++
 backend/python/mlx-distributed/coordinator.py | 104 ++++
 backend/python/mlx-distributed/install.sh     |  15 +
 backend/python/mlx-distributed/mlx_cache.py   | 266 +++++++++
 .../mlx-distributed/requirements-cpu.txt      |   2 +
 .../mlx-distributed/requirements-cublas12.txt |   2 +
 .../mlx-distributed/requirements-cublas13.txt |   2 +
 .../mlx-distributed/requirements-l4t12.txt    |   2 +
 .../mlx-distributed/requirements-l4t13.txt    |   2 +
 .../mlx-distributed/requirements-mps.txt      |   1 +
 .../python/mlx-distributed/requirements.txt   |   4 +
 backend/python/mlx-distributed/run.sh         |  11 +
 backend/python/mlx-distributed/sharding.py    | 136 +++++
 backend/python/mlx-distributed/test.py        |  87 +++
 backend/python/mlx-distributed/test.sh        |  12 +
 core/application/p2p.go                       |  28 +-
 core/cli/run.go                               |  12 +-
 core/cli/worker/worker.go                     |   6 +-
 core/cli/worker/worker_mlx_common.go          |  65 +++
 core/cli/worker/worker_mlx_distributed.go     |  62 +++
 core/cli/worker/worker_p2p.go                 |   4 +-
 core/cli/worker/worker_p2p_mlx.go             |  97 ++++
 core/config/application_config.go             |  13 +-
 core/explorer/discovery.go                    |   2 +-
 core/http/endpoints/localai/p2p.go            |   3 +-
 core/http/react-ui/src/pages/P2P.jsx          | 167 ++--
 core/http/routes/ui_api.go                    |  62 ++-
 core/http/views/p2p.html                      |  10 +-
 core/p2p/node.go                              |   5 +-
 core/schema/localai.go                        |   3 +-
 docs/content/features/mlx-distributed.md      | 212 ++++++++
 36 files changed, 2016 insertions(+), 73 deletions(-)
 create mode 100644 backend/python/mlx-distributed/Makefile
 create mode 100644 backend/python/mlx-distributed/backend.py
 create mode 100644 backend/python/mlx-distributed/coordinator.py
 create mode 100644 backend/python/mlx-distributed/install.sh
 create mode 100644 backend/python/mlx-distributed/mlx_cache.py
 create mode 100644 backend/python/mlx-distributed/requirements-cpu.txt
 create mode 100644 backend/python/mlx-distributed/requirements-cublas12.txt
 create mode 100644 backend/python/mlx-distributed/requirements-cublas13.txt
 create mode 100644 backend/python/mlx-distributed/requirements-l4t12.txt
 create mode 100644 backend/python/mlx-distributed/requirements-l4t13.txt
 create mode 100644 backend/python/mlx-distributed/requirements-mps.txt
 create mode 100644 backend/python/mlx-distributed/requirements.txt
 create mode 100644 backend/python/mlx-distributed/run.sh
 create mode 100644 backend/python/mlx-distributed/sharding.py
 create mode 100644 backend/python/mlx-distributed/test.py
 create mode 100644 backend/python/mlx-distributed/test.sh
 create mode 100644 core/cli/worker/worker_mlx_common.go
 create mode 100644 core/cli/worker/worker_mlx_distributed.go
 create mode 100644 core/cli/worker/worker_p2p_mlx.go
 create mode 100644 docs/content/features/mlx-distributed.md

diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml
index 80e317339..a1bd2cab3 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -157,6 +157,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-cpu-mlx-distributed' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'true' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' # CUDA 12 builds - build-type: 'cublas' cuda-major-version: "12" @@ -470,6 +483,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "8" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-mlx-distributed' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" @@ -822,6 +848,19 @@ jobs: backend: "mlx-audio" dockerfile: "./backend/Dockerfile.python" context: "./" + - build-type: 'l4t' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx-distributed' + runs-on: 'ubuntu-24.04-arm' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + ubuntu-version: '2404' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" @@ -926,6 +965,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: 
'-gpu-nvidia-cuda-13-mlx-distributed' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" @@ -1423,6 +1475,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' + - build-type: 'l4t' + cuda-major-version: "12" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-mlx-distributed' + runs-on: 'ubuntu-24.04-arm' + base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" + skip-drivers: 'true' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2204' # SYCL additional backends - build-type: 'intel' cuda-major-version: "" @@ -2016,6 +2081,9 @@ jobs: - backend: "mlx-audio" tag-suffix: "-metal-darwin-arm64-mlx-audio" build-type: "mps" + - backend: "mlx-distributed" + tag-suffix: "-metal-darwin-arm64-mlx-distributed" + build-type: "mps" - backend: "stablediffusion-ggml" tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml" build-type: "metal" diff --git a/Makefile b/Makefile index 54c21088a..15af1d39a 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo 
backends/voxcpm backends/whisperx backends/ace-step backends/voxtral +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral GOCMD=go GOTEST=$(GOCMD) test @@ -451,6 +451,10 @@ backends/mlx-audio: BACKEND=mlx-audio $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)" +backends/mlx-distributed: + BACKEND=mlx-distributed $(MAKE) build-darwin-python-backend + ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-distributed.tar)" + backends/stablediffusion-ggml-darwin: BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)" @@ -495,6 +499,7 @@ BACKEND_NEMO = nemo|python|.|false|true BACKEND_VOXCPM = voxcpm|python|.|false|true BACKEND_WHISPERX = whisperx|python|.|false|true BACKEND_ACE_STEP = ace-step|python|.|false|true +BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true # Helper function to build docker image for a backend # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG) @@ -548,12 +553,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_NEMO))) $(eval $(call generate-docker-build-target,$(BACKEND_VOXCPM))) $(eval $(call 
generate-docker-build-target,$(BACKEND_WHISPERX))) $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) +$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral +docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral docker-build-mlx-distributed ######################################################## ### Mock Backend for E2E Tests diff --git a/README.md b/README.md index 3dae0c688..699fed47b 100644 --- a/README.md +++ b/README.md @@ -482,6 +482,7 @@ LocalAI couldn't have been built without the help of great software already avai - https://github.com/EdVince/Stable-Diffusion-NCNN - https://github.com/ggerganov/whisper.cpp - https://github.com/rhasspy/piper +- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation ## 🤗 Contributors diff --git a/backend/index.yaml b/backend/index.yaml index 
e518170ca..392afa735 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -259,6 +259,31 @@ nvidia-l4t: "nvidia-l4t-mlx-audio" nvidia-l4t-cuda-12: "nvidia-l4t-mlx-audio" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-audio" +- &mlx-distributed + name: "mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-distributed" + icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 + urls: + - https://github.com/ml-explore/mlx-lm + mirrors: + - localai/localai-backends:latest-metal-darwin-arm64-mlx-distributed + license: MIT + description: | + Run distributed LLM inference with MLX across multiple Apple Silicon Macs + tags: + - text-to-text + - LLM + - MLX + - distributed + capabilities: + default: "cpu-mlx-distributed" + nvidia: "cuda12-mlx-distributed" + metal: "metal-mlx-distributed" + nvidia-cuda-12: "cuda12-mlx-distributed" + nvidia-cuda-13: "cuda13-mlx-distributed" + nvidia-l4t: "nvidia-l4t-mlx-distributed" + nvidia-l4t-cuda-12: "nvidia-l4t-mlx-distributed" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-distributed" - &rerankers name: "rerankers" alias: "rerankers" @@ -791,6 +816,11 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx-audio +- !!merge <<: *mlx-distributed + name: "mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-distributed" + mirrors: + - localai/localai-backends:master-metal-darwin-arm64-mlx-distributed ## mlx - !!merge <<: *mlx name: "cpu-mlx" @@ -944,6 +974,57 @@ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio +## mlx-distributed +- !!merge <<: *mlx-distributed + name: "cpu-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx-distributed" + mirrors: + - localai/localai-backends:latest-cpu-mlx-distributed 
+- !!merge <<: *mlx-distributed + name: "cpu-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx-distributed" + mirrors: + - localai/localai-backends:master-cpu-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda12-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda12-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx-distributed" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-12-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx-distributed" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-13-mlx-distributed +- !!merge <<: *mlx-distributed + name: "nvidia-l4t-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx-distributed" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-mlx-distributed +- !!merge <<: *mlx-distributed + name: "nvidia-l4t-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx-distributed" + mirrors: + - localai/localai-backends:master-nvidia-l4t-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-nvidia-l4t-arm64-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-nvidia-l4t-arm64-mlx-distributed-development" + uri: 
"quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed" + mirrors: + - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed - !!merge <<: *kitten-tts name: "kitten-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts" diff --git a/backend/python/mlx-distributed/Makefile b/backend/python/mlx-distributed/Makefile new file mode 100644 index 000000000..b322efbe9 --- /dev/null +++ b/backend/python/mlx-distributed/Makefile @@ -0,0 +1,23 @@ +.PHONY: mlx-distributed +mlx-distributed: + bash install.sh + +.PHONY: run +run: + @echo "Running mlx-distributed..." + bash run.sh + @echo "mlx-distributed run." + +.PHONY: test +test: + @echo "Testing mlx-distributed..." + bash test.sh + @echo "mlx-distributed tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py new file mode 100644 index 000000000..b21a98070 --- /dev/null +++ b/backend/python/mlx-distributed/backend.py @@ -0,0 +1,509 @@ +#!/usr/bin/env python3 +""" +MLX Distributed Inference Backend for LocalAI. + +Two startup modes: + +1. Server mode (started by LocalAI automatically): + run.sh --addr localhost:50051 + Distributed config comes from LoadModel options or env vars. + +2. Worker mode (started by CLI for remote ranks): + run.sh --worker --hostfile hosts.json --rank 1 --backend ring + Enters a loop waiting for commands from rank 0. +""" +import asyncio +from concurrent import futures +import argparse +import json +import os +import signal +import sys +import tempfile +from typing import List + +import grpc + +import backend_pb2 +import backend_pb2_grpc + +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + + +def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None): + """Initialize MLX distributed runtime. 
+
+    Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank
+    binds to its own entry (hostfile[rank]) and connects to neighbors for the
+    ring pipeline.
+
+    JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names.
+    MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that
+    helps all ranks establish RDMA connections.
+    """
+    import mlx.core as mx
+
+    if backend == "ring":
+        os.environ["MLX_HOSTFILE"] = hostfile
+        os.environ["MLX_RANK"] = str(rank)
+        os.environ["MLX_RING_VERBOSE"] = "1"
+        return mx.distributed.init(backend="ring", strict=True)
+    elif backend == "jaccl":
+        os.environ["MLX_IBV_DEVICES"] = hostfile
+        os.environ["MLX_RANK"] = str(rank)
+        if coordinator:
+            os.environ["MLX_JACCL_COORDINATOR"] = coordinator
+        return mx.distributed.init(backend="jaccl", strict=True)
+    else:
+        raise ValueError(f"Unknown backend: {backend}")
+
+
+def is_float(s):
+    try:
+        float(s)
+        return True
+    except ValueError:
+        return False
+
+
+def is_int(s):
+    try:
+        int(s)
+        return True
+    except ValueError:
+        return False
+
+
+def parse_options(options):
+    """Parse key:value option strings into a dict."""
+    result = {}
+    for opt in options:
+        if ":" not in opt:
+            continue
+        key, value = opt.split(":", 1)
+        if is_int(value):
+            value = int(value)
+        elif is_float(value):
+            value = float(value)
+        elif value.lower() in ["true", "false"]:
+            value = value.lower() == "true"
+        result[key] = value
+    return result
+
+
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+    """gRPC servicer for distributed MLX inference (runs on rank 0).
+
+    When started by LocalAI (server mode), distributed init happens at
+    LoadModel time using config from model options or environment variables.
+ """ + + def __init__(self): + self.group = None + self.dist_backend = None + self.model = None + self.tokenizer = None + self.coordinator = None + self.options = {} + self.lru_cache = None + self.model_key = None + self.max_kv_size = None + + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + async def LoadModel(self, request, context): + try: + import mlx.core as mx + from mlx_lm import load + from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache + + print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr) + + self.options = parse_options(request.Options) + print(f"Options: {self.options}", file=sys.stderr) + + # Get distributed config from model options, falling back to env vars. + # If neither is set, run as single-node (no distributed). + hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", "")) + dist_backend = str(self.options.get("distributed_backend", + os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring"))) + # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator). + # Not in model options — rank 0 is the coordinator, workers get + # the address via their own --coordinator CLI flag. 
+ jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "") + + if hostfile: + from coordinator import DistributedCoordinator, CMD_LOAD_MODEL + from sharding import pipeline_auto_parallel + + print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr) + self.dist_backend = dist_backend + self.group = mlx_distributed_init( + rank=0, + hostfile=hostfile, + backend=dist_backend, + coordinator=jaccl_coordinator or None, + ) + self.coordinator = DistributedCoordinator(self.group) + self.coordinator.broadcast_command(CMD_LOAD_MODEL) + self.coordinator.broadcast_model_name(request.Model) + else: + print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr) + + # Build tokenizer config from request and options + tokenizer_config = {} + if request.TrustRemoteCode or self.options.get("trust_remote_code", False): + tokenizer_config["trust_remote_code"] = True + # Token overrides from options + for key in ["eos_token", "pad_token", "bos_token", "unk_token", + "sep_token", "cls_token", "mask_token"]: + if key in self.options: + tokenizer_config[key] = self.options[key] + + if tokenizer_config: + print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr) + self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config) + else: + self.model, self.tokenizer = load(request.Model) + + if self.group is not None: + from sharding import pipeline_auto_parallel + self.model = pipeline_auto_parallel(self.model, self.group) + print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr) + else: + # Single-node: set up prompt cache for efficient generation + from mlx_cache import ThreadSafeLRUPromptCache + max_cache_entries = self.options.get("max_cache_entries", 10) + self.max_kv_size = self.options.get("max_kv_size", None) + self.model_key = request.Model + self.lru_cache = ThreadSafeLRUPromptCache( + max_size=max_cache_entries, + 
can_trim_fn=can_trim_prompt_cache, + trim_fn=trim_prompt_cache, + ) + print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr) + + except Exception as err: + print(f"[Rank 0] Error loading model: {err}", file=sys.stderr) + return backend_pb2.Result(success=False, message=f"Error loading model: {err}") + + return backend_pb2.Result(message="Model loaded successfully", success=True) + + async def Predict(self, request, context): + prompt_cache = None + cache_key = None + + try: + import mlx.core as mx + from mlx_lm import stream_generate + from mlx_lm.sample_utils import make_sampler + + prompt_text = self._prepare_prompt(request) + tokens = self._get_tokens_from_prompt(prompt_text) + + if self.coordinator: + from coordinator import CMD_GENERATE + self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) + self.coordinator.broadcast_tokens(tokens) + + max_tokens, sampler_params = self._build_generation_params(request) + + if self.coordinator: + gen_params = self.coordinator.broadcast_generation_params( + max_tokens=max_tokens, + temperature=sampler_params.get('temp', 0.6), + top_p=sampler_params.get('top_p', 1.0), + ) + max_tokens = gen_params["max_tokens"] + + sampler = make_sampler(**sampler_params) + + # Use prompt cache in single-node mode + gen_kwargs = {} + if self.lru_cache is not None: + from mlx_lm.models.cache import make_prompt_cache + cache_key = list(tokens) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + gen_kwargs['prompt_cache'] = prompt_cache + tokens = remaining_tokens if remaining_tokens else cache_key + + generated = [] + for response in stream_generate( + self.model, + self.tokenizer, + prompt=tokens, + max_tokens=max_tokens, + sampler=sampler, + **gen_kwargs, + ): + generated.append(response.text) + if cache_key is not None: + 
cache_key.append(response.token) + + if self.lru_cache is not None and cache_key is not None: + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) + + return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8')) + + except Exception as e: + print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Generation failed: {str(e)}") + return backend_pb2.Reply(message=bytes("", encoding='utf-8')) + + async def PredictStream(self, request, context): + prompt_cache = None + cache_key = None + + try: + import mlx.core as mx + from mlx_lm import stream_generate + from mlx_lm.sample_utils import make_sampler + + prompt_text = self._prepare_prompt(request) + tokens = self._get_tokens_from_prompt(prompt_text) + + if self.coordinator: + from coordinator import CMD_GENERATE + self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) + self.coordinator.broadcast_tokens(tokens) + + max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) + + if self.coordinator: + gen_params = self.coordinator.broadcast_generation_params( + max_tokens=max_tokens, + temperature=sampler_params.get('temp', 0.6), + top_p=sampler_params.get('top_p', 1.0), + ) + max_tokens = gen_params["max_tokens"] + + sampler = make_sampler(**sampler_params) + + # Use prompt cache in single-node mode + gen_kwargs = {} + if self.lru_cache is not None: + from mlx_lm.models.cache import make_prompt_cache + cache_key = list(tokens) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + gen_kwargs['prompt_cache'] = prompt_cache + tokens = remaining_tokens if remaining_tokens else cache_key + + for response in stream_generate( + self.model, + self.tokenizer, + prompt=tokens, + max_tokens=max_tokens, + 
sampler=sampler, + **gen_kwargs, + ): + if cache_key is not None: + cache_key.append(response.token) + yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) + + except Exception as e: + print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Streaming failed: {str(e)}") + yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + + finally: + if self.lru_cache is not None and prompt_cache is not None and cache_key is not None: + try: + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) + except Exception as e: + print(f"Error inserting cache: {e}", file=sys.stderr) + + def Embedding(self, request, context): + print("Embeddings not supported in MLX distributed backend", file=sys.stderr) + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Embeddings are not supported in the MLX distributed backend.") + return backend_pb2.EmbeddingResult() + + def _prepare_prompt(self, request): + if not request.Prompt and request.UseTokenizerTemplate and request.Messages: + messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages] + return self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return request.Prompt + + def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]: + tokens = self.tokenizer.encode(prompt_text) + if hasattr(tokens, 'tolist'): + return tokens.tolist() + return list(tokens) + + def _build_generation_params(self, request, default_max_tokens=200): + import mlx.core as mx + + max_tokens = getattr(request, 'Tokens', default_max_tokens) + if max_tokens == 0: + max_tokens = default_max_tokens + + temp = getattr(request, 'Temperature', 0.0) + if temp == 0.0: + temp = 0.6 + + top_p = getattr(request, 'TopP', 0.0) + if top_p == 0.0: + top_p = 1.0 + + sampler_params = { + 'temp': temp, + 'top_p': top_p, + 'min_p': getattr(request, 'MinP', 0.0), + 'top_k': 
getattr(request, 'TopK', 0),
+            'xtc_threshold': 0.0,
+            'xtc_probability': 0.0,
+        }
+
+        seed = getattr(request, 'Seed', 0)
+        if seed != 0:
+            mx.random.seed(seed)
+
+        if hasattr(self, 'options'):
+            if 'max_tokens' in self.options:
+                max_tokens = self.options['max_tokens']
+            option_mapping = {
+                'temp': 'temp',
+                'temperature': 'temp',
+                'top_p': 'top_p',
+                'min_p': 'min_p',
+                'top_k': 'top_k',
+                'xtc_threshold': 'xtc_threshold',
+                'xtc_probability': 'xtc_probability',
+            }
+            for opt_key, param_key in option_mapping.items():
+                if opt_key in self.options:
+                    sampler_params[param_key] = self.options[opt_key]
+            if 'seed' in self.options:
+                mx.random.seed(self.options['seed'])
+
+        # XTC special tokens
+        xtc_special_tokens = []
+        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
+            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
+        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
+            xtc_special_tokens = [self.tokenizer.eos_token_id]
+        try:
+            newline_tokens = self.tokenizer.encode("\n")
+            xtc_special_tokens.extend(newline_tokens)
+        except Exception:
+            pass
+        sampler_params['xtc_special_tokens'] = xtc_special_tokens
+
+        return max_tokens, sampler_params
+
+
+def run_worker(group):
+    """Worker loop for ranks > 0.
Waits for commands from rank 0.""" + from mlx_lm import load, stream_generate + from mlx_lm.sample_utils import make_sampler + from coordinator import DistributedCoordinator, CMD_LOAD_MODEL, CMD_GENERATE, CMD_SHUTDOWN + from sharding import pipeline_auto_parallel + import mlx.core as mx + + coordinator = DistributedCoordinator(group) + model = None + tokenizer = None + + print(f"[Rank {group.rank()}] Worker started, waiting for commands...", file=sys.stderr) + + while True: + cmd, payload_size = coordinator.wait_for_command() + + if cmd == CMD_LOAD_MODEL: + model_name = coordinator.broadcast_model_name() + print(f"[Rank {group.rank()}] Loading model: {model_name}", file=sys.stderr) + model, tokenizer = load(model_name) + model = pipeline_auto_parallel(model, group) + print(f"[Rank {group.rank()}] Model loaded and sharded", file=sys.stderr) + + elif cmd == CMD_GENERATE: + if model is None: + print(f"[Rank {group.rank()}] No model loaded, skipping generate", file=sys.stderr) + continue + + token_count = coordinator.broadcast_token_count(payload_size) + tokens_array = coordinator.broadcast_tokens([0] * token_count) + tokens = tokens_array.tolist() + + gen_params = coordinator.broadcast_generation_params() + + sampler = make_sampler( + temp=gen_params["temperature"], + top_p=gen_params["top_p"], + ) + + for _ in stream_generate( + model, tokenizer, + prompt=tokens, + max_tokens=gen_params["max_tokens"], + sampler=sampler, + ): + pass + + elif cmd == CMD_SHUTDOWN: + print(f"[Rank {group.rank()}] Shutting down", file=sys.stderr) + break + + +async def serve(address): + server = grpc.aio.server( + migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), + ('grpc.max_send_message_length', 50 * 1024 * 1024), + ('grpc.max_receive_message_length', 50 * 1024 * 1024), + ], + ) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + + loop = 
asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5))) + + await server.start() + print(f"[Rank 0] gRPC server listening on {address}", file=sys.stderr) + await server.wait_for_termination() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MLX Distributed Backend") + parser.add_argument("--addr", default="localhost:50051", + help="gRPC listen address (used by LocalAI to send requests)") + parser.add_argument("--worker", action="store_true", + help="Run in worker mode (for remote ranks started by CLI)") + parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"], + help="ring = TCP pipeline parallelism, jaccl = RDMA tensor parallelism") + parser.add_argument("--hostfile", default=None, + help="Path to hostfile JSON (required for --worker mode)") + parser.add_argument("--rank", type=int, default=0, + help="Rank of this process (0 = server, >0 = worker)") + parser.add_argument("--coordinator", default=None, + help="JACCL coordinator ip:port (jaccl backend only)") + args = parser.parse_args() + + if args.worker: + if not args.hostfile: + print("Error: --hostfile is required in worker mode", file=sys.stderr) + sys.exit(1) + group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator) + run_worker(group) + else: + # Server mode: started by LocalAI with just --addr. + # Distributed init deferred to LoadModel (reads config from model options/env vars). + asyncio.run(serve(args.addr)) diff --git a/backend/python/mlx-distributed/coordinator.py b/backend/python/mlx-distributed/coordinator.py new file mode 100644 index 000000000..1ec686a45 --- /dev/null +++ b/backend/python/mlx-distributed/coordinator.py @@ -0,0 +1,104 @@ +""" +Distributed coordination using MLX distributed primitives. + +Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather. 
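The coordinator's command protocol relies on a collective-sum trick: rank 0 contributes the real command array and every other rank contributes zeros, so the element-wise sum that `all_sum` delivers to all ranks is exactly rank 0's array. A minimal pure-Python sketch of the pattern, with plain lists standing in for `mx.array` and a local `all_sum` standing in for `mx.distributed.all_sum` (stand-ins for illustration, not the MLX API):

```python
def all_sum(contributions):
    # Stand-in for mx.distributed.all_sum: every rank receives the
    # element-wise sum of all ranks' arrays.
    return [sum(vals) for vals in zip(*contributions)]

CMD_LOAD_MODEL = 2
world_size = 3

# Rank 0 contributes the real (command, payload_size) pair;
# worker ranks contribute zeros of the same shape.
contribs = [[CMD_LOAD_MODEL, 0]] + [[0, 0] for _ in range(world_size - 1)]
result = all_sum(contribs)

# Every rank decodes the same command.
assert result == [CMD_LOAD_MODEL, 0]
```

Because zeros are the additive identity, this turns a sum reduction into a broadcast without needing a dedicated broadcast primitive.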
+Worker ranks wait in a loop for commands from rank 0. +""" +import json +import struct + +import mlx.core as mx + + +CMD_IDLE = 0 +CMD_GENERATE = 1 +CMD_LOAD_MODEL = 2 +CMD_SHUTDOWN = -1 + + +class DistributedCoordinator: + def __init__(self, group): + self.group = group + self.rank = group.rank() + self.world_size = group.size() + + def broadcast_command(self, cmd, payload_size=0): + """Rank 0 broadcasts a command to all ranks. + + Uses all_sum with only rank 0 providing non-zero values so every + rank receives the same command array. + """ + if self.rank == 0: + cmd_array = mx.array([cmd, payload_size], dtype=mx.int32) + else: + cmd_array = mx.zeros((2,), dtype=mx.int32) + result = mx.distributed.all_sum(cmd_array, group=self.group) + mx.eval(result) + return int(result[0].item()), int(result[1].item()) + + def broadcast_tokens(self, tokens): + """Broadcast input token ids from rank 0 to all ranks. + + Rank 0 provides the real token array; other ranks provide zeros of the + same shape. ``all_sum`` ensures every rank ends up with identical data. 
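Broadcasting a variable-length token sequence this way needs two rounds, because workers must know the length before they can allocate a matching zero buffer. A small sketch of that two-phase protocol (again with plain lists as hypothetical stand-ins for the MLX collectives):

```python
def all_sum(per_rank):
    # Stand-in for mx.distributed.all_sum over equal-shaped arrays.
    return [sum(vals) for vals in zip(*per_rank)]

def two_phase_broadcast(tokens, world_size):
    # Phase 1: broadcast the token count so workers can size a zero buffer.
    counts = [[len(tokens)]] + [[0]] * (world_size - 1)
    n = all_sum(counts)[0]
    # Phase 2: broadcast the payload; workers contribute zeros of length n.
    payloads = [tokens] + [[0] * n] * (world_size - 1)
    return all_sum(payloads)

# All ranks end up with rank 0's token ids.
assert two_phase_broadcast([5, 9, 2], world_size=3) == [5, 9, 2]
```

This mirrors `broadcast_token_count` followed by `broadcast_tokens` in the coordinator above.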
+ """ + if self.rank == 0: + token_array = mx.array(tokens, dtype=mx.int32) + else: + token_array = mx.zeros((len(tokens),), dtype=mx.int32) + result = mx.distributed.all_sum(token_array, group=self.group) + mx.eval(result) + return result + + def broadcast_token_count(self, count): + """Broadcast the number of tokens so workers can prepare a buffer.""" + if self.rank == 0: + count_array = mx.array([count], dtype=mx.int32) + else: + count_array = mx.zeros((1,), dtype=mx.int32) + result = mx.distributed.all_sum(count_array, group=self.group) + mx.eval(result) + return int(result[0].item()) + + def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0): + """Broadcast generation parameters from rank 0.""" + if self.rank == 0: + params = mx.array([max_tokens, temperature, top_p], dtype=mx.float32) + else: + params = mx.zeros((3,), dtype=mx.float32) + result = mx.distributed.all_sum(params, group=self.group) + mx.eval(result) + return { + "max_tokens": int(result[0].item()), + "temperature": float(result[1].item()), + "top_p": float(result[2].item()), + } + + def wait_for_command(self): + """Worker ranks block here until rank 0 broadcasts a command.""" + return self.broadcast_command(CMD_IDLE, 0) + + def broadcast_model_name(self, model_name=""): + """Broadcast model name string from rank 0 to all ranks. + + Encodes the model name as int32 codepoints so it can travel via + all_sum. 
+ """ + if self.rank == 0: + encoded = [ord(c) for c in model_name] + # First broadcast the length + length = self.broadcast_token_count(len(encoded)) + if length > 0: + name_array = mx.array(encoded, dtype=mx.int32) + result = mx.distributed.all_sum(name_array, group=self.group) + mx.eval(result) + return model_name + return "" + else: + length = self.broadcast_token_count(0) + if length > 0: + name_array = mx.zeros((length,), dtype=mx.int32) + result = mx.distributed.all_sum(name_array, group=self.group) + mx.eval(result) + return "".join(chr(int(c.item())) for c in result) + return "" diff --git a/backend/python/mlx-distributed/install.sh b/backend/python/mlx-distributed/install.sh new file mode 100644 index 000000000..253ee0c13 --- /dev/null +++ b/backend/python/mlx-distributed/install.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -e + +USE_PIP=true +PYTHON_VERSION="" + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +installRequirements diff --git a/backend/python/mlx-distributed/mlx_cache.py b/backend/python/mlx-distributed/mlx_cache.py new file mode 100644 index 000000000..6ec2bb9ba --- /dev/null +++ b/backend/python/mlx-distributed/mlx_cache.py @@ -0,0 +1,266 @@ +""" +Thread-safe LRU prompt cache for MLX-based backends. + +Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.) +with thread-safety additions for LocalAI's gRPC backend. + +Usage: + from mlx_cache import ThreadSafeLRUPromptCache + + # In LoadModel: + self.lru_cache = ThreadSafeLRUPromptCache(max_size=10) + + # In Predict/PredictStream: + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens) + # ... generate ... 
+ self.lru_cache.insert_cache(model_key, tokens, prompt_cache) +""" +import copy +import threading +from collections import deque +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + + +@dataclass +class CacheEntry: + """A cache entry with reference counting.""" + prompt_cache: List[Any] + count: int + + +@dataclass +class SearchResult: + """Result of searching the cache trie.""" + model: Any + exact: Optional[List[int]] + shorter: Optional[List[int]] + longer: Optional[List[int]] + common_prefix: int + + +class ThreadSafeLRUPromptCache: + """ + Thread-safe LRU cache with prefix matching for prompt KV caches. + + This cache stores KV caches keyed by token sequences and supports: + - Exact match: Return the cache for the exact token sequence + - Shorter prefix match: Return a cache for a prefix of the tokens + - Longer prefix match: If a longer sequence is cached and can be trimmed + - LRU eviction: When max_size is exceeded, evict least recently used + + Thread safety is provided via a threading.Lock that protects all + cache operations. + + Args: + max_size: Maximum number of cache entries (default: 10) + can_trim_fn: Optional function to check if a cache can be trimmed + trim_fn: Optional function to trim a cache + """ + + def __init__( + self, + max_size: int = 10, + can_trim_fn: Optional[Any] = None, + trim_fn: Optional[Any] = None, + ): + self.max_size = max_size + self._cache = {} + self._lru = deque() + self._lock = threading.Lock() + + # Optional trim functions (for longer prefix reuse) + self._can_trim_fn = can_trim_fn + self._trim_fn = trim_fn + + def _search(self, model, tokens: List[int]) -> SearchResult: + """ + Search the cache for a prompt cache. Return exact or close match. + + The cache is organized as a trie where each node is keyed by a token. + This allows efficient prefix matching. 
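The trie organization described above can be sketched in a few lines: each node is a dict keyed by token id, and a node carrying a `"cache"` key marks the end of a cached sequence. Walking the prompt's tokens from the root finds the longest cached prefix in O(prompt length). A simplified illustration (names are hypothetical; the real class adds refcounting, locking, and eviction):

```python
def insert(trie, tokens, kv_cache):
    # Build the path for this token sequence and attach the cache at its end.
    node = trie
    for tok in tokens:
        node = node.setdefault(tok, {})
    node["cache"] = kv_cache

def longest_cached_prefix(trie, tokens):
    # Walk the trie along the prompt, remembering the deepest node that
    # actually holds a cache. Returns the index of the last covered token,
    # or -1 if no cached prefix exists.
    node, best = trie, -1
    for i, tok in enumerate(tokens):
        if tok not in node:
            break
        node = node[tok]
        if "cache" in node:
            best = i
    return best

trie = {}
insert(trie, [1, 2, 3], "kv-for-123")
assert longest_cached_prefix(trie, [1, 2, 3, 4]) == 2   # reuse, process [4]
assert longest_cached_prefix(trie, [9, 9]) == -1        # cold prompt
```

An exact or prefix hit means only the tokens past the cached prefix need a forward pass, which is what makes multi-turn chat cheap.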
+ """ + if model not in self._cache: + return SearchResult(model, None, None, None, 0) + + current = self._cache[model] + last_cache_index = -1 + index = 0 + + # Traverse the trie following the token sequence + while index < len(tokens) and tokens[index] in current: + current = current[tokens[index]] + if "cache" in current: + last_cache_index = index + index += 1 + + # Exact match - no need to search for longer or shorter caches + if last_cache_index == len(tokens) - 1: + return SearchResult(model, tuple(tokens), None, None, 0) + + # Find the shorter cache (a prefix that has a cache) + # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior. + # Single-token prefixes are not matched, which allows longer cached + # sequences to be preferred for trimming. This is acceptable because + # real prompts with chat templates are always many tokens. + shorter = None + if last_cache_index > 0: + shorter = tuple(tokens[: last_cache_index + 1]) + + # Check for caches that are longer than our token sequence + longer = None + common_prefix = index + if index > 0 and last_cache_index <= 0: + best = None + stack = [(current, [])] + while stack: + current, extra = stack.pop() + if "cache" in current: + if best is None or len(extra) < len(best): + best = extra + else: + for tok in current: + stack.append((current[tok], extra + [tok])) + if best is not None: + longer = tuple(tokens[:index] + best) + + return SearchResult(model, None, shorter, longer, common_prefix) + + def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """Get a cache entry by traversing the trie.""" + current = self._cache[model] + for tok in tokens: + current = current[tok] + return current["cache"] + + def _delete(self, model, tokens: Tuple[int, ...]) -> None: + """Delete a cache entry and clean up empty trie nodes.""" + path = [self._cache[model]] + for tok in tokens: + path.append(path[-1][tok]) + del path[-1]["cache"] + + # Clean up empty nodes bottom-up + for i in 
reversed(range(len(tokens))): + d_prev, d, t = path[i], path[i + 1], tokens[i] + if len(d) > 0: + break + del d_prev[t] + + def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """ + Extract a cache entry for exclusive use. + + If the entry has count > 1, deep copy and decrement. + If count == 1, remove from cache entirely. + """ + cache_entry = self._get(model, tokens) + if cache_entry.count == 1: + self._delete(model, tokens) + self._lru.remove((model, tokens)) + return cache_entry + + cache_entry.count -= 1 + return CacheEntry( + copy.deepcopy(cache_entry.prompt_cache), + 1, + ) + + def fetch_nearest_cache( + self, model, tokens: List[int] + ) -> Tuple[Optional[List[Any]], List[int]]: + """ + Fetch the nearest cache for the given token sequence. + + Thread-safe. Returns (cache, remaining_tokens) where: + - cache: The KV cache to use (or None if no cache found) + - remaining_tokens: Tokens that still need to be processed + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence for the prompt + + Returns: + Tuple of (prompt_cache, remaining_tokens) + """ + with self._lock: + tokens_tuple = tuple(tokens) + result = self._search(model, tokens) + + # Exact match - extract and return + if result.exact is not None: + cache_entry = self._extract(result.model, result.exact) + return cache_entry.prompt_cache, [] + + # Shorter prefix match - extract and return remaining + if result.shorter is not None: + cache_entry = self._extract(result.model, result.shorter) + prefix_len = len(result.shorter) + return cache_entry.prompt_cache, list(tokens[prefix_len:]) + + # Longer prefix match - try to trim if possible + if result.longer is not None and self._can_trim_fn is not None: + cache_entry = self._get(result.model, result.longer) + if self._can_trim_fn(cache_entry.prompt_cache): + # Deep copy and trim + trimmed_cache = copy.deepcopy(cache_entry.prompt_cache) + prefix = min(len(tokens) - 1, result.common_prefix) + 
num_to_trim = len(result.longer) - prefix + if self._trim_fn is not None: + self._trim_fn(trimmed_cache, num_to_trim) + return trimmed_cache, list(tokens[prefix:]) + + # No match found + return None, list(tokens) + + def insert_cache( + self, model, tokens: List[int], prompt_cache: List[Any] + ) -> None: + """ + Insert a cache entry after generation completes. + + Thread-safe. Handles LRU eviction if max_size is exceeded. + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence (prompt + generated) + prompt_cache: The KV cache to store + """ + with self._lock: + tokens_tuple = tuple(tokens) + + if model not in self._cache: + self._cache[model] = {} + current = self._cache[model] + + # Build trie path + for tok in tokens_tuple: + if tok not in current: + current[tok] = {} + current = current[tok] + + # Update or create entry + if "cache" in current: + current["cache"].count += 1 + self._lru.remove((model, tokens_tuple)) + else: + current["cache"] = CacheEntry(prompt_cache, 1) + + # Update LRU order + self._lru.append((model, tokens_tuple)) + + # Evict if over capacity + if len(self._lru) > self.max_size: + evict_model, evict_tokens = self._lru.popleft() + self._delete(evict_model, evict_tokens) + + def clear(self) -> None: + """Clear all cache entries. Thread-safe.""" + with self._lock: + self._cache.clear() + self._lru.clear() + + def __len__(self) -> int: + """Return the number of cache entries. 
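The eviction bookkeeping in `insert_cache` pairs the trie with a deque ordered from least to most recently used: re-inserted keys move to the back, and overflow pops from the front. A reduced sketch of just that mechanism (a plain dict stands in for the trie; names are illustrative):

```python
from collections import deque

class TinyLRU:
    # Minimal model of the LRU side of ThreadSafeLRUPromptCache:
    # a key->entry map plus a deque ordered oldest-first.
    def __init__(self, max_size):
        self.max_size = max_size
        self.entries = {}
        self.lru = deque()

    def insert(self, key, value):
        if key in self.entries:
            self.lru.remove(key)       # refresh recency on re-insert
        self.entries[key] = value
        self.lru.append(key)
        if len(self.lru) > self.max_size:
            evicted = self.lru.popleft()  # drop least recently used
            del self.entries[evicted]

lru = TinyLRU(max_size=2)
lru.insert(("model", (1,)), "kv-a")
lru.insert(("model", (2,)), "kv-b")
lru.insert(("model", (3,)), "kv-c")   # evicts the oldest entry
assert ("model", (1,)) not in lru.entries
assert len(lru.entries) == 2
```

In the real class the same `(model, tokens_tuple)` keys index both structures under one lock, so recency updates and trie mutations stay consistent.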
Thread-safe.""" + with self._lock: + return len(self._lru) diff --git a/backend/python/mlx-distributed/requirements-cpu.txt b/backend/python/mlx-distributed/requirements-cpu.txt new file mode 100644 index 000000000..a381b1461 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-cpu.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cpu] diff --git a/backend/python/mlx-distributed/requirements-cublas12.txt b/backend/python/mlx-distributed/requirements-cublas12.txt new file mode 100644 index 000000000..dc057533b --- /dev/null +++ b/backend/python/mlx-distributed/requirements-cublas12.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda12] diff --git a/backend/python/mlx-distributed/requirements-cublas13.txt b/backend/python/mlx-distributed/requirements-cublas13.txt new file mode 100644 index 000000000..40cc69453 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-cublas13.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda13] diff --git a/backend/python/mlx-distributed/requirements-l4t12.txt b/backend/python/mlx-distributed/requirements-l4t12.txt new file mode 100644 index 000000000..dc057533b --- /dev/null +++ b/backend/python/mlx-distributed/requirements-l4t12.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda12] diff --git a/backend/python/mlx-distributed/requirements-l4t13.txt b/backend/python/mlx-distributed/requirements-l4t13.txt new file mode 100644 index 000000000..40cc69453 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-l4t13.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda13] diff --git a/backend/python/mlx-distributed/requirements-mps.txt b/backend/python/mlx-distributed/requirements-mps.txt new file mode 100644 index 000000000..1b47ff0e5 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-mps.txt @@ -0,0 +1 @@ +mlx-lm diff --git a/backend/python/mlx-distributed/requirements.txt b/backend/python/mlx-distributed/requirements.txt new file mode 100644 index 000000000..fe67cdb50 --- /dev/null +++ b/backend/python/mlx-distributed/requirements.txt @@ -0,0 +1,4 @@ +grpcio==1.71.0 
+protobuf +certifi +setuptools diff --git a/backend/python/mlx-distributed/run.sh b/backend/python/mlx-distributed/run.sh new file mode 100644 index 000000000..8f608a541 --- /dev/null +++ b/backend/python/mlx-distributed/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/mlx-distributed/sharding.py b/backend/python/mlx-distributed/sharding.py new file mode 100644 index 000000000..0ea3dad54 --- /dev/null +++ b/backend/python/mlx-distributed/sharding.py @@ -0,0 +1,136 @@ +""" +Auto-parallelism for MLX distributed inference. + +Provides pipeline parallelism (Ring backend) by wrapping model layers with +distributed send/recv operations. Ported from exo's auto_parallel.py with +simplifications for LocalAI's use case. +""" +import mlx.core as mx +import mlx.nn as nn + + +class PipelineFirstLayer(nn.Module): + """Wraps the first layer on each rank to receive from the previous rank.""" + + def __init__(self, original_layer, rank, group): + super().__init__() + dict.__setitem__(self, "_original_layer", original_layer) + self.rank = rank + self.group = group + + @property + def original_layer(self): + return self["_original_layer"] + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self["_original_layer"], name) + + def __call__(self, x, *args, **kwargs): + if self.rank != 0: + mx.eval(x) + x = mx.distributed.recv_like(x, self.rank - 1, group=self.group) + mx.eval(x) + return self.original_layer(x, *args, **kwargs) + + +class PipelineLastLayer(nn.Module): + """Wraps the last layer on each rank to send to the next rank.""" + + def __init__(self, original_layer, rank, world_size, group): + super().__init__() + dict.__setitem__(self, "_original_layer", original_layer) + self.rank = rank + self.world_size = world_size 
+ self.group = group + + @property + def original_layer(self): + return self["_original_layer"] + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self["_original_layer"], name) + + def __call__(self, x, *args, **kwargs): + output = self.original_layer(x, *args, **kwargs) + mx.eval(output) + if self.rank != self.world_size - 1: + output = mx.distributed.send( + output, (self.rank + 1) % self.world_size, group=self.group + ) + mx.eval(output) + # Gather output from all ranks so every rank has the final result + output = mx.distributed.all_gather(output, group=self.group)[ + -output.shape[0] : + ] + mx.eval(output) + return output + + +def get_inner_model(model): + """Get the inner model (model.model or model.transformer).""" + for attr in ("model", "transformer"): + inner = getattr(model, attr, None) + if isinstance(inner, nn.Module): + # Some models have model.model (e.g. language_model.model) + inner_inner = getattr(inner, "model", None) + if isinstance(inner_inner, nn.Module): + return inner_inner + return inner + raise ValueError("Model must have a 'model' or 'transformer' attribute") + + +def get_layers(inner_model): + """Get the list of transformer layers.""" + for attr in ("layers", "h"): + layers = getattr(inner_model, attr, None) + if layers is not None: + return layers + raise ValueError("Model must have a 'layers' or 'h' attribute") + + +def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None): + """Apply pipeline parallelism to a model. + + Each rank only keeps its slice of layers. The first layer receives from + the previous rank, and the last layer sends to the next rank. 
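The automatic layer split used when `start_layer`/`end_layer` are not given is a balanced partition: every rank gets `total // world_size` layers, and the first `total % world_size` ranks get one extra. Extracted as a standalone function for clarity (same arithmetic as in `pipeline_auto_parallel` above):

```python
def layer_range(rank, world_size, total_layers):
    # Balanced split: the first `remainder` ranks each take one extra layer,
    # so ranges are contiguous and cover every layer exactly once.
    per_rank = total_layers // world_size
    remainder = total_layers % world_size
    start = rank * per_rank + min(rank, remainder)
    end = start + per_rank + (1 if rank < remainder else 0)
    return start, end

# 26 layers over 3 ranks -> 9 + 9 + 8, contiguous and exhaustive.
ranges = [layer_range(r, 3, 26) for r in range(3)]
assert ranges == [(0, 9), (9, 18), (18, 26)]
```

Each rank then keeps only its slice, wraps its first kept layer to `recv` from the previous rank and its last to `send` to the next, which is the essence of the pipeline parallelism applied here.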
+ + Args: + model: The MLX model (must have model.layers or similar) + group: The distributed group + start_layer: First layer index for this rank (auto-computed if None) + end_layer: Last layer index (exclusive) for this rank (auto-computed if None) + """ + rank = group.rank() + world_size = group.size() + + inner = get_inner_model(model) + layers = list(get_layers(inner)) + total_layers = len(layers) + + if start_layer is None or end_layer is None: + layers_per_rank = total_layers // world_size + remainder = total_layers % world_size + start_layer = rank * layers_per_rank + min(rank, remainder) + end_layer = start_layer + layers_per_rank + (1 if rank < remainder else 0) + + layers = layers[start_layer:end_layer] + for layer in layers: + mx.eval(layer) + + # Wrap first and last layers + layers[0] = PipelineFirstLayer(layers[0], rank, group=group) + layers[-1] = PipelineLastLayer(layers[-1], rank, world_size, group=group) + + # Replace layers on the inner model + if hasattr(inner, "layers"): + inner.layers = layers + elif hasattr(inner, "h"): + inner.h = layers + + return model diff --git a/backend/python/mlx-distributed/test.py b/backend/python/mlx-distributed/test.py new file mode 100644 index 000000000..4cb1440ed --- /dev/null +++ b/backend/python/mlx-distributed/test.py @@ -0,0 +1,87 @@ +import unittest +import subprocess +import time + +import grpc +import backend_pb2 +import backend_pb2_grpc + + +class TestBackendServicer(unittest.TestCase): + def setUp(self): + self.service = subprocess.Popen( + ["python", "backend.py", "--addr", "localhost:50051"] + ) + time.sleep(10) + + def tearDown(self) -> None: + self.service.terminate() + self.service.wait() + + def test_server_startup(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b'OK') + except Exception as err: + print(err) + 
self.fail("Server failed to start") + finally: + self.tearDown() + + def test_load_model(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + self.assertEqual(response.message, "Model loaded successfully") + except Exception as err: + print(err) + self.fail("LoadModel service failed") + finally: + self.tearDown() + + def test_text(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + req = backend_pb2.PredictOptions(Prompt="The capital of France is") + resp = stub.Predict(req) + self.assertIsNotNone(resp.message) + except Exception as err: + print(err) + self.fail("text service failed") + finally: + self.tearDown() + + def test_sampling_params(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + + req = backend_pb2.PredictOptions( + Prompt="The capital of France is", + TopP=0.8, + Tokens=50, + Temperature=0.7, + TopK=40, + MinP=0.05, + Seed=42, + ) + resp = stub.Predict(req) + self.assertIsNotNone(resp.message) + except Exception as err: + print(err) + self.fail("sampling params service failed") + finally: + self.tearDown() diff --git a/backend/python/mlx-distributed/test.sh b/backend/python/mlx-distributed/test.sh new file mode 100644 index 000000000..f31ae54e4 --- /dev/null +++ b/backend/python/mlx-distributed/test.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) + +if [ -d 
$backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/core/application/p2p.go b/core/application/p2p.go index 99527e841..8522a121d 100644 --- a/core/application/p2p.go +++ b/core/application/p2p.go @@ -84,19 +84,37 @@ func (a *Application) StartP2P() error { n = node } - // Attach a ServiceDiscoverer to the p2p node + // Attach a ServiceDiscoverer to the p2p node for llama.cpp workers xlog.Info("Starting P2P server discovery...") - if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node schema.NodeData) { + if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID), func(serviceID string, node schema.NodeData) { var tunnelAddresses []string - for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) { + for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID)) { if v.IsOnline() { tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) } else { xlog.Info("Node is offline", "node", v.ID) } } - if a.applicationConfig.TunnelCallback != nil { - a.applicationConfig.TunnelCallback(tunnelAddresses) + if a.applicationConfig.LlamaCPPTunnelCallback != nil { + a.applicationConfig.LlamaCPPTunnelCallback(tunnelAddresses) + } + }, true); err != nil { + return err + } + + // Attach a ServiceDiscoverer for MLX distributed workers + xlog.Info("Starting MLX P2P worker discovery...") + if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.MLXWorkerID), func(serviceID string, node schema.NodeData) { + var tunnelAddresses []string + for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.MLXWorkerID)) { + if v.IsOnline() { + tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) + } else { + xlog.Info("MLX node is offline", "node", 
v.ID) + } + } + if a.applicationConfig.MLXTunnelCallback != nil { + a.applicationConfig.MLXTunnelCallback(tunnelAddresses) } }, true); err != nil { return err diff --git a/core/cli/run.go b/core/cli/run.go index 57aa53207..75100930c 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -2,8 +2,10 @@ package cli import ( "context" + "encoding/json" "fmt" "os" + "path/filepath" "strings" "time" @@ -171,12 +173,18 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithMachineTag(r.MachineTag), config.WithAPIAddress(r.Address), config.WithAgentJobRetentionDays(r.AgentJobRetentionDays), - config.WithTunnelCallback(func(tunnels []string) { + config.WithLlamaCPPTunnelCallback(func(tunnels []string) { tunnelEnvVar := strings.Join(tunnels, ",") - // TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar) xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar) }), + config.WithMLXTunnelCallback(func(tunnels []string) { + hostfile := filepath.Join(os.TempDir(), "localai_mlx_hostfile.json") + data, _ := json.Marshal(tunnels) + os.WriteFile(hostfile, data, 0644) + os.Setenv("MLX_DISTRIBUTED_HOSTFILE", hostfile) + xlog.Debug("setting MLX_DISTRIBUTED_HOSTFILE", "value", hostfile, "tunnels", tunnels) + }), } if r.DisableMetricsEndpoint { diff --git a/core/cli/worker/worker.go b/core/cli/worker/worker.go index 0a636c3bf..1ddb972a0 100644 --- a/core/cli/worker/worker.go +++ b/core/cli/worker/worker.go @@ -8,6 +8,8 @@ type WorkerFlags struct { } type Worker struct { - P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"` - LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"` + P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"` + P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX 
distributed worker in P2P mode (requires a token)"` + LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"` + MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"` } diff --git a/core/cli/worker/worker_mlx_common.go b/core/cli/worker/worker_mlx_common.go new file mode 100644 index 000000000..2a1b2443f --- /dev/null +++ b/core/cli/worker/worker_mlx_common.go @@ -0,0 +1,65 @@ +package worker + +import ( + "context" + "encoding/json" + "errors" + "os/exec" + "path/filepath" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/xlog" +) + +const mlxDistributedGalleryName = "mlx-distributed" + +// findMLXDistributedBackendPath finds or installs the mlx-distributed backend +// and returns the directory containing run.sh. 
+func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) { + backends, err := gallery.ListSystemBackends(systemState) + if err != nil { + return "", err + } + + backend, ok := backends.Get(mlxDistributedGalleryName) + if !ok { + ml := model.NewModelLoader(systemState) + var gals []config.Gallery + if err := json.Unmarshal([]byte(galleries), &gals); err != nil { + xlog.Error("failed loading galleries", "error", err) + return "", err + } + if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true); err != nil { + xlog.Error("mlx-distributed backend not found, failed to install it", "error", err) + return "", err + } + // Re-fetch after install + backends, err = gallery.ListSystemBackends(systemState) + if err != nil { + return "", err + } + backend, ok = backends.Get(mlxDistributedGalleryName) + if !ok { + return "", errors.New("mlx-distributed backend not found after install") + } + } + + backendPath := filepath.Dir(backend.RunFile) + if backendPath == "" { + return "", errors.New("mlx-distributed backend not found, install it first") + } + return backendPath, nil +} + +// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend. +// backendPath is the directory containing run.sh (empty string to fall back to +// running backend.py directly via python3). +func buildMLXCommand(backendPath string, args ...string) *exec.Cmd { + if backendPath != "" { + return exec.Command(filepath.Join(backendPath, "run.sh"), args...) + } + return exec.Command("python3", append([]string{"backend.py"}, args...)...) 
+} diff --git a/core/cli/worker/worker_mlx_distributed.go b/core/cli/worker/worker_mlx_distributed.go new file mode 100644 index 000000000..e701d927a --- /dev/null +++ b/core/cli/worker/worker_mlx_distributed.go @@ -0,0 +1,62 @@ +package worker + +import ( + "fmt" + "os" + "syscall" + + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/xlog" +) + +type MLXDistributed struct { + WorkerFlags `embed:""` + Hostfile string `env:"MLX_DISTRIBUTED_HOSTFILE" required:"" help:"Path to hostfile JSON. Ring: array of 'ip:port' where entry i is rank i's listen address. JACCL: 2D matrix of RDMA device names."` + Rank int `env:"MLX_RANK" required:"" help:"Rank of this process (0 = gRPC server + ring participant, >0 = worker only)"` + Backend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend: 'ring' (TCP pipeline parallelism) or 'jaccl' (RDMA tensor parallelism)"` + Addr string `env:"MLX_DISTRIBUTED_ADDR" default:"localhost:50051" help:"gRPC API listen address for LocalAI (rank 0 only, separate from ring communication)"` + Coordinator string `env:"MLX_JACCL_COORDINATOR" default:"" help:"JACCL coordinator ip:port — rank 0's address where it accepts RDMA setup connections (all ranks must use the same value)"` +} + +func (r *MLXDistributed) Run(ctx *cliContext.Context) error { + systemState, err := system.GetSystemState( + system.WithBackendPath(r.BackendsPath), + system.WithBackendSystemPath(r.BackendsSystemPath), + ) + if err != nil { + return err + } + + backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState) + if err != nil { + return fmt.Errorf("cannot find mlx-distributed backend: %w", err) + } + + args := []string{ + "--backend", r.Backend, + "--hostfile", r.Hostfile, + "--rank", fmt.Sprint(r.Rank), + } + + if r.Rank == 0 { + args = append(args, "--addr", r.Addr) + } else { + args = append(args, "--worker") + } + + if r.Backend == "jaccl" && 
r.Coordinator != "" { + args = append(args, "--coordinator", r.Coordinator) + } + + cmd := buildMLXCommand(backendPath, args...) + runSh := cmd.Path + + xlog.Info("Starting mlx-distributed", "rank", r.Rank, "backend", r.Backend, "hostfile", r.Hostfile) + + return syscall.Exec( + runSh, + append([]string{runSh}, args...), + os.Environ(), + ) +} diff --git a/core/cli/worker/worker_p2p.go b/core/cli/worker/worker_p2p.go index 868357ccf..b9baf3bf3 100644 --- a/core/cli/worker/worker_p2p.go +++ b/core/cli/worker/worker_p2p.go @@ -62,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error { p = r.RunnerPort } - _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) + _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID)) if err != nil { return err } @@ -104,7 +104,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error { } }() - _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) + _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID)) if err != nil { return err } diff --git a/core/cli/worker/worker_p2p_mlx.go b/core/cli/worker/worker_p2p_mlx.go new file mode 100644 index 000000000..ffa1c4ef9 --- /dev/null +++ b/core/cli/worker/worker_p2p_mlx.go @@ -0,0 +1,97 @@ +package worker + +import ( + "context" + "fmt" + "os" + "time" + + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/LocalAI/pkg/signals" + "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/xlog" + "github.com/phayes/freeport" +) + +type P2PMLX struct { + WorkerFlags `embed:""` + Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"` + Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode" group:"p2p"` + MLXListenPort string 
`env:"MLX_LISTEN_PORT" default:"5555" help:"Port for MLX distributed communication"` + MLXBackend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend (ring or jaccl)"` +} + +func (r *P2PMLX) Run(ctx *cliContext.Context) error { + if r.Token == "" { + return fmt.Errorf("token is required") + } + + systemState, err := system.GetSystemState( + system.WithBackendPath(r.BackendsPath), + system.WithBackendSystemPath(r.BackendsSystemPath), + ) + if err != nil { + return err + } + + port, err := freeport.GetFreePort() + if err != nil { + return err + } + if r.MLXListenPort != "" { + fmt.Sscanf(r.MLXListenPort, "%d", &port) + } + + address := "127.0.0.1" + + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState) + if err != nil { + xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err) + } + + go func() { + for { + hostfile := os.Getenv("MLX_DISTRIBUTED_HOSTFILE") + if hostfile == "" { + xlog.Info("Waiting for MLX_DISTRIBUTED_HOSTFILE to be set by P2P discovery...") + time.Sleep(2 * time.Second) + continue + } + + xlog.Info("Starting mlx-distributed worker", "address", address, "port", port, "hostfile", hostfile) + + cmd := buildMLXCommand(backendPath, + "--worker", + "--backend", r.MLXBackend, + "--hostfile", hostfile, + "--rank", "0", + ) + cmd.Env = os.Environ() + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + if err := cmd.Run(); err != nil { + xlog.Error("mlx-distributed worker exited", "error", err) + } + time.Sleep(2 * time.Second) + } + }() + + _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID)) + if err != nil { + return err + } + + xlog.Info("MLX distributed worker registered on P2P network", "address", address, "port", port) + + signals.RegisterGracefulTerminationHandler(func() { + cancel() + }) + + 
<-c.Done() + return nil +} diff --git a/core/config/application_config.go b/core/config/application_config.go index 6f1a86622..8edc22c00 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -82,7 +82,8 @@ type ApplicationConfig struct { APIAddress string - TunnelCallback func(tunnels []string) + LlamaCPPTunnelCallback func(tunnels []string) + MLXTunnelCallback func(tunnels []string) DisableRuntimeSettings bool @@ -457,9 +458,15 @@ func WithContextSize(ctxSize int) AppOption { } } -func WithTunnelCallback(callback func(tunnels []string)) AppOption { +func WithLlamaCPPTunnelCallback(callback func(tunnels []string)) AppOption { return func(o *ApplicationConfig) { - o.TunnelCallback = callback + o.LlamaCPPTunnelCallback = callback + } +} + +func WithMLXTunnelCallback(callback func(tunnels []string)) AppOption { + return func(o *ApplicationConfig) { + o.MLXTunnelCallback = callback } } diff --git a/core/explorer/discovery.go b/core/explorer/discovery.go index 989e784d3..36a193b71 100644 --- a/core/explorer/discovery.go +++ b/core/explorer/discovery.go @@ -156,7 +156,7 @@ func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockch for d := range data { toScanForWorkers := false cd := ClusterData{} - isWorkerCluster := d == p2p.WorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.WorkerID)) + isWorkerCluster := d == p2p.LlamaCPPWorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.LlamaCPPWorkerID)) isFederatedCluster := d == p2p.FederatedID || (strings.Contains(d, "_") && strings.Contains(d, p2p.FederatedID)) switch { case isWorkerCluster: diff --git a/core/http/endpoints/localai/p2p.go b/core/http/endpoints/localai/p2p.go index afd7d048d..cc630be4f 100644 --- a/core/http/endpoints/localai/p2p.go +++ b/core/http/endpoints/localai/p2p.go @@ -15,8 +15,9 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc { // Render index return func(c echo.Context) error { return 
c.JSON(200, schema.P2PNodesResponse{ - Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)), + LlamaCPPNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)), FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)), + MLXNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)), }) } } diff --git a/core/http/react-ui/src/pages/P2P.jsx b/core/http/react-ui/src/pages/P2P.jsx index c1f97a4ec..f77b5db21 100644 --- a/core/http/react-ui/src/pages/P2P.jsx +++ b/core/http/react-ui/src/pages/P2P.jsx @@ -103,8 +103,9 @@ function StepNumber({ n, bg, color }) { export default function P2P() { const { addToast } = useOutletContext() const [workers, setWorkers] = useState([]) + const [mlxWorkers, setMlxWorkers] = useState([]) const [federation, setFederation] = useState([]) - const [stats, setStats] = useState({ workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 } }) + const [stats, setStats] = useState({ llama_cpp_workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 }, mlx_workers: { online: 0, total: 0 } }) const [loading, setLoading] = useState(true) const [enabled, setEnabled] = useState(false) const [token, setToken] = useState('') @@ -129,7 +130,13 @@ export default function P2P() { if (p2pToken) { if (wRes.status === 'fulfilled') { const data = wRes.value - setWorkers(data?.nodes || (Array.isArray(data) ? data : [])) + // Handle both old format ({nodes: [...]}) and new grouped format ({llama_cpp: {nodes: [...]}, mlx: {nodes: [...]}}) + if (data?.llama_cpp) { + setWorkers(data.llama_cpp.nodes || []) + setMlxWorkers(data.mlx?.nodes || []) + } else { + setWorkers(data?.nodes || (Array.isArray(data) ? data : [])) + } } if (fRes.status === 'fulfilled') { const data = fRes.value @@ -274,8 +281,10 @@ export default function P2P() { // ── P2P Enabled ── const fedOnline = stats.federated?.online ?? 
0 const fedTotal = stats.federated?.total ?? 0 - const wrkOnline = stats.workers?.online ?? 0 - const wrkTotal = stats.workers?.total ?? 0 + const llamaOnline = stats.llama_cpp_workers?.online ?? 0 + const llamaTotal = stats.llama_cpp_workers?.total ?? 0 + const mlxOnline = stats.mlx_workers?.online ?? 0 + const mlxTotal = stats.mlx_workers?.total ?? 0 return (
@@ -401,7 +410,7 @@ export default function P2P() { fontSize: '0.75rem', color: activeTab === 'sharding' ? 'var(--color-accent)' : 'var(--color-text-muted)', }}> - {wrkOnline}/{wrkTotal} workers + {llamaOnline + mlxOnline}/{llamaTotal + mlxTotal} workers
@@ -562,6 +571,21 @@ export default function P2P() { borderRadius: 'var(--radius-lg)', overflow: 'hidden', }}>
+
+ + Different from federation: Federation distributes whole requests across instances. Model sharding splits a single model across machines for joint inference. +
+ + {/* ── llama.cpp RPC Workers Section ── */} +

+ + llama.cpp RPC Workers +

+ {/* Architecture diagram */}

Model weights are split across RPC workers. Each worker holds a portion of the model layers in its memory (GPU or CPU). - The LocalAI instance orchestrates inference by communicating with all workers via RPC.

-
- - Different from federation: Federation distributes whole requests across instances. Model sharding splits a single model's weights across machines for joint inference. Currently only supported with llama.cpp based models. -
- - {/* Status + nodes */} + {/* llama.cpp Status + nodes */}
-

Connected Workers

+

Connected Workers

- 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{wrkOnline} - /{wrkTotal} + 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{llamaOnline} + /{llamaTotal}
@@ -633,7 +647,7 @@ export default function P2P() { borderRadius: 'var(--radius-lg)', }}> -

No workers available

+

No llama.cpp workers connected

Start workers to see them here

) : ( @@ -645,17 +659,99 @@ export default function P2P() { )} - {/* Setup Guide */} + {/* ── MLX Distributed Workers Section ── */} +
+

+ + MLX Distributed Workers +

+ + {/* MLX Architecture diagram */} +
+
+
+
+ +
+
LocalAI
+
Rank 0
+
+
+ + Ring / JACCL +
+
+
+ {['Layers 1-16', 'Layers 17-32'].map((label, i) => ( +
+
+ +
+
{label}
+
+ ))} +
+
MLX Workers
+
Pipeline parallel
+
+
+

+ MLX distributed uses native Apple Silicon communication. All nodes execute model code simultaneously via pipeline or tensor parallelism. +

+
+ + {/* MLX Status + nodes */} +
+

Connected MLX Workers

+
+ 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{mlxOnline} + /{mlxTotal} +
+
+ + {mlxWorkers.length === 0 ? ( +
+ +

No MLX workers connected

+

Start MLX workers on Apple Silicon Macs

+
+ ) : ( +
+ {mlxWorkers.map((node, i) => ( + + ))} +
+ )} +
+ + {/* Setup Guides */}

- Start a llama.cpp RPC Worker + Setup Workers

+

llama.cpp RPC Worker

Each worker exposes its GPU/CPU memory as a shard for distributed model inference.

@@ -663,17 +759,24 @@ export default function P2P() { command={`docker run -ti --net host \\\n -e TOKEN="${token}" \\\n --name local-ai-worker \\\n localai/localai:latest-cpu worker p2p-llama-cpp-rpc`} addToast={addToast} /> -

- Run this on each machine you want to contribute as a shard. The worker will automatically join the network and advertise its resources. -

+
-

- For GPU images and all available options, see the{' '} - Container images - {' '}and{' '} - Worker docs. +

+

MLX Distributed Worker

+

+ Run on Apple Silicon Macs to participate in distributed MLX inference via pipeline parallelism. +

+ +

+ For more information, see the{' '} + MLX Distributed docs.

diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go index 446086767..68bbc1325 100644 --- a/core/http/routes/ui_api.go +++ b/core/http/routes/ui_api.go @@ -1041,11 +1041,12 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model // P2P APIs app.GET("/api/p2p/workers", func(c echo.Context) error { - nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)) + llamaNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)) + mlxNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)) - nodesJSON := make([]map[string]interface{}, 0, len(nodes)) - for _, n := range nodes { - nodesJSON = append(nodesJSON, map[string]interface{}{ + llamaJSON := make([]map[string]any, 0, len(llamaNodes)) + for _, n := range llamaNodes { + llamaJSON = append(llamaJSON, map[string]any{ "name": n.Name, "id": n.ID, "tunnelAddress": n.TunnelAddress, @@ -1055,8 +1056,27 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model }) } - return c.JSON(200, map[string]interface{}{ - "nodes": nodesJSON, + mlxJSON := make([]map[string]any, 0, len(mlxNodes)) + for _, n := range mlxNodes { + mlxJSON = append(mlxJSON, map[string]any{ + "name": n.Name, + "id": n.ID, + "tunnelAddress": n.TunnelAddress, + "serviceID": n.ServiceID, + "lastSeen": n.LastSeen, + "isOnline": n.IsOnline(), + }) + } + + return c.JSON(200, map[string]any{ + "llama_cpp": map[string]any{ + "nodes": llamaJSON, + }, + "mlx": map[string]any{ + "nodes": mlxJSON, + }, + // Keep backward-compatible "nodes" key with llama.cpp workers + "nodes": llamaJSON, }) }) @@ -1081,13 +1101,14 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model }) app.GET("/api/p2p/stats", func(c echo.Context) error { - workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)) + llamaCPPNodes := 
p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)) federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)) + mlxWorkerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)) - workersOnline := 0 - for _, n := range workerNodes { + llamaCPPOnline := 0 + for _, n := range llamaCPPNodes { if n.IsOnline() { - workersOnline++ + llamaCPPOnline++ } } @@ -1098,15 +1119,26 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model } } - return c.JSON(200, map[string]interface{}{ - "workers": map[string]interface{}{ - "online": workersOnline, - "total": len(workerNodes), + mlxWorkersOnline := 0 + for _, n := range mlxWorkerNodes { + if n.IsOnline() { + mlxWorkersOnline++ + } + } + + return c.JSON(200, map[string]any{ + "llama_cpp_workers": map[string]any{ + "online": llamaCPPOnline, + "total": len(llamaCPPNodes), }, - "federated": map[string]interface{}{ + "federated": map[string]any{ "online": federatedOnline, "total": len(federatedNodes), }, + "mlx_workers": map[string]any{ + "online": mlxWorkersOnline, + "total": len(mlxWorkerNodes), + }, }) }) diff --git a/core/http/views/p2p.html b/core/http/views/p2p.html index c05c488f8..b3b595cbd 100644 --- a/core/http/views/p2p.html +++ b/core/http/views/p2p.html @@ -262,8 +262,8 @@
- - / + + /

workers

@@ -469,8 +469,8 @@ docker run -ti --net host -e TOKEN="{{.P2PToken}}" --
Active Workers
- - / + + /
@@ -657,7 +657,7 @@ function p2pNetwork() { workerNodes: [], federationNodes: [], stats: { - workers: { online: 0, total: 0 }, + llama_cpp_workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 } }, diff --git a/core/p2p/node.go b/core/p2p/node.go index 78efb77ca..4996531dc 100644 --- a/core/p2p/node.go +++ b/core/p2p/node.go @@ -9,8 +9,9 @@ import ( ) const ( - defaultServicesID = "services" - WorkerID = "worker" + defaultServicesID = "services" + LlamaCPPWorkerID = "worker" + MLXWorkerID = "mlx_worker" ) var mu sync.Mutex diff --git a/core/schema/localai.go b/core/schema/localai.go index ccf3e6e10..6f98bf320 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -136,8 +136,9 @@ func (d NodeData) IsOnline() bool { } type P2PNodesResponse struct { - Nodes []NodeData `json:"nodes" yaml:"nodes"` + LlamaCPPNodes []NodeData `json:"llama_cpp_nodes" yaml:"llama_cpp_nodes"` FederatedNodes []NodeData `json:"federated_nodes" yaml:"federated_nodes"` + MLXNodes []NodeData `json:"mlx_nodes" yaml:"mlx_nodes"` } type SysInfoModel struct { diff --git a/docs/content/features/mlx-distributed.md b/docs/content/features/mlx-distributed.md new file mode 100644 index 000000000..9e20474fd --- /dev/null +++ b/docs/content/features/mlx-distributed.md @@ -0,0 +1,212 @@ ++++ +disableToc = false +title = "(experimental) MLX Distributed Inference" +weight = 18 +url = '/features/mlx-distributed/' ++++ + +MLX distributed inference allows you to split large language models across multiple Apple Silicon Macs (or other devices) for joint inference. Unlike federation (which distributes whole requests), MLX distributed splits a single model's layers across machines so they all participate in every forward pass. + +## How It Works + +MLX distributed uses **pipeline parallelism** via the Ring backend: each node holds a slice of the model's layers. During inference, activations flow from rank 0 through each subsequent rank in a pipeline. The last rank gathers the final output. 
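To make the layer split concrete, here is a hypothetical sketch of contiguous layer sharding (not the backend's actual implementation, which lives in `sharding.py` in this patch; the even-split-with-remainder policy is an assumption for illustration):

```python
# Hypothetical sketch of contiguous layer sharding for pipeline parallelism.
# Not the backend's real auto-parallel logic (see sharding.py); it only
# illustrates how each rank ends up holding a slice of the model's layers.

def shard_layers(n_layers: int, n_ranks: int, rank: int) -> range:
    """Split n_layers into contiguous slices; earlier ranks absorb the remainder."""
    base, rem = divmod(n_layers, n_ranks)
    start = rank * base + min(rank, rem)
    end = start + base + (1 if rank < rem else 0)
    return range(start, end)

# A 32-layer model over two ranks: rank 0 runs the first half, rank 1 the
# second, and activations flow from rank 0 to rank 1 on every forward pass.
print(list(shard_layers(32, 2, 0)))  # layers held by rank 0
print(list(shard_layers(32, 2, 1)))  # layers held by rank 1
```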
+ +For high-bandwidth setups (e.g., Thunderbolt-connected Macs), **JACCL** (tensor parallelism via RDMA) is also supported, where each rank holds all layers but with sharded weights. + +## Prerequisites + +- Two or more machines with MLX installed (Apple Silicon recommended) +- Network connectivity between all nodes (TCP for Ring, RDMA/Thunderbolt for JACCL) +- Same model accessible on all nodes (e.g., from Hugging Face cache) + +## Quick Start with P2P + +The simplest way to use MLX distributed is with LocalAI's P2P auto-discovery. + +### 1. Start LocalAI with P2P + +```bash +docker run -ti --net host \ + --name local-ai \ + localai/localai:latest-metal-darwin-arm64 run --p2p +``` + +This generates a network token. Copy it for the next step. + +### 2. Start MLX Workers + +On each additional Mac: + +```bash +docker run -ti --net host \ + -e TOKEN="<TOKEN>" \ + --name local-ai-mlx-worker \ + localai/localai:latest-metal-darwin-arm64 worker p2p-mlx +``` + +Workers auto-register on the P2P network. The LocalAI server discovers them and generates a hostfile for MLX distributed. + +### 3. Use the Model + +Load any MLX-compatible model. The `mlx-distributed` backend will automatically shard it across all available ranks: + +```yaml +name: llama-distributed +backend: mlx-distributed +parameters: + model: mlx-community/Llama-3.2-1B-Instruct-4bit +``` + +## Model Configuration + +The `mlx-distributed` backend is started automatically by LocalAI like any other backend. You configure distributed inference through the model YAML file using the `options` field: + +### Ring Backend (TCP) + +```yaml +name: llama-distributed +backend: mlx-distributed +parameters: + model: mlx-community/Llama-3.2-1B-Instruct-4bit +options: + - "hostfile:/path/to/hosts.json" + - "distributed_backend:ring" +``` + +The **hostfile** is a JSON array where entry `i` is the `"ip:port"` that **rank `i` listens on** for ring communication. All ranks must use the same hostfile so they know how to reach each other.
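The indexing convention can be sketched with a hypothetical helper (not part of the backend): each rank reads its own listen address at its index, and, in a ring topology, forwards activations to the next index. The send-to-`(rank + 1) % world_size` neighbor rule is an illustrative assumption about the ring, not confirmed backend behavior.

```python
import json

def ring_addresses(hostfile_path: str, rank: int) -> tuple:
    """Hypothetical helper: entry i is rank i's listen address; in a ring,
    assume rank i forwards activations to rank (i + 1) % world_size."""
    with open(hostfile_path) as f:
        hosts = json.load(f)  # e.g. ["192.168.1.10:5555", "192.168.1.11:5555"]
    listen_addr = hosts[rank]                      # where this rank listens
    next_addr = hosts[(rank + 1) % len(hosts)]     # assumed ring neighbor
    return listen_addr, next_addr
```

With a two-entry hostfile, rank 1's "next" rank wraps back around to rank 0, which is what closes the ring.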
+ +**Example:** Two Macs — Mac A (`192.168.1.10`) and Mac B (`192.168.1.11`): + +```json +["192.168.1.10:5555", "192.168.1.11:5555"] +``` + +- Entry 0 (`192.168.1.10:5555`) — the address rank 0 (Mac A) listens on for ring communication +- Entry 1 (`192.168.1.11:5555`) — the address rank 1 (Mac B) listens on for ring communication + +Port 5555 is arbitrary — use any available port, but it must be open in your firewall. + +### JACCL Backend (RDMA/Thunderbolt) + +```yaml +name: llama-distributed +backend: mlx-distributed +parameters: + model: mlx-community/Llama-3.2-1B-Instruct-4bit +options: + - "hostfile:/path/to/devices.json" + - "distributed_backend:jaccl" +``` + +The **device matrix** is a JSON 2D array describing the RDMA device name between each pair of ranks. The diagonal is `null` (a rank doesn't talk to itself): + +```json +[ + [null, "rdma_thunderbolt0"], + ["rdma_thunderbolt0", null] +] +``` + +JACCL requires a **coordinator** — a TCP service that helps all ranks establish RDMA connections. Rank 0 (the LocalAI machine) is always the coordinator. Workers are told the coordinator address via their `--coordinator` CLI flag (see [Starting Workers](#jaccl-workers) below). + +### Without hostfile (single-node) + +If no `hostfile` option is set and no `MLX_DISTRIBUTED_HOSTFILE` environment variable exists, the backend runs as a regular single-node MLX backend. This is useful for testing or when you don't need distributed inference. + +### Available Options + +| Option | Description | +|--------|-------------| +| `hostfile` | Path to the hostfile JSON. Ring: array of `"ip:port"`. JACCL: device matrix. 
| +| `distributed_backend` | `ring` (default) or `jaccl` | +| `trust_remote_code` | Allow trust_remote_code for the tokenizer | +| `max_tokens` | Override default max generation tokens | +| `temperature` / `temp` | Sampling temperature | +| `top_p` | Top-p sampling | + +These can also be set via environment variables (`MLX_DISTRIBUTED_HOSTFILE`, `MLX_DISTRIBUTED_BACKEND`) which are used as fallbacks when the model options don't specify them. + +## Starting Workers + +LocalAI starts the rank 0 process (gRPC server) automatically when the model is loaded. But you still need to start **worker processes** (ranks 1, 2, ...) on the other machines. These workers participate in every forward pass but don't serve any API — they wait for commands from rank 0. + +### Ring Workers + +On each worker machine, start a worker with the same hostfile: + +```bash +local-ai worker mlx-distributed --hostfile hosts.json --rank 1 +``` + +The `--rank` must match the worker's position in the hostfile. For example, if `hosts.json` is `["192.168.1.10:5555", "192.168.1.11:5555", "192.168.1.12:5555"]`, then: +- Rank 0: started automatically by LocalAI on `192.168.1.10` +- Rank 1: `local-ai worker mlx-distributed --hostfile hosts.json --rank 1` on `192.168.1.11` +- Rank 2: `local-ai worker mlx-distributed --hostfile hosts.json --rank 2` on `192.168.1.12` + +### JACCL Workers + +```bash +local-ai worker mlx-distributed \ + --hostfile devices.json \ + --rank 1 \ + --backend jaccl \ + --coordinator 192.168.1.10:5555 +``` + +The `--coordinator` address is the IP of the machine running LocalAI (rank 0) with any available port. Rank 0 binds the coordinator service there; workers connect to it to establish RDMA connections. + +### Worker Startup Order + +Start workers **before** loading the model in LocalAI. When LocalAI sends the LoadModel request, rank 0 initializes `mx.distributed` which tries to connect to all ranks listed in the hostfile. If workers aren't running yet, it will time out. 
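A pre-flight check along these lines (hypothetical, not shipped with the backend) can confirm every rank's listen address is reachable before you load the model. It assumes the ring listeners accept plain TCP connections while waiting, which may depend on the MLX ring implementation:

```python
import json
import socket

def unreachable_ranks(hostfile_path: str, timeout: float = 2.0) -> list:
    """Hypothetical pre-flight check for the Ring backend: attempt a TCP
    connect to every rank's listen address and report the ones that fail."""
    with open(hostfile_path) as f:
        hosts = json.load(f)
    failed = []
    for rank, addr in enumerate(hosts):
        host, port = addr.rsplit(":", 1)
        try:
            with socket.create_connection((host, int(port)), timeout=timeout):
                pass  # connected: this rank's listener is up
        except OSError:
            failed.append((rank, addr))
    return failed
```

An empty result means every rank is listening; otherwise, start the reported workers before sending the LoadModel request.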
+ +## Advanced: Manual Rank 0 + +For advanced use cases, you can also run rank 0 manually as an external gRPC backend instead of letting LocalAI start it automatically: + +```bash +# On Mac A: start rank 0 manually +local-ai worker mlx-distributed --hostfile hosts.json --rank 0 --addr 192.168.1.10:50051 + +# On Mac B: start rank 1 +local-ai worker mlx-distributed --hostfile hosts.json --rank 1 + +# On any machine: start LocalAI pointing at rank 0 +local-ai run --external-grpc-backends "mlx-distributed:192.168.1.10:50051" +``` + +Then use a model config with `backend: mlx-distributed` (no need for `hostfile` in options since rank 0 already has it from CLI args). + +## CLI Reference + +### `worker mlx-distributed` + +Starts a worker or manual rank 0 process. + +| Flag | Env | Default | Description | +|------|-----|---------|-------------| +| `--hostfile` | `MLX_DISTRIBUTED_HOSTFILE` | *(required)* | Path to hostfile JSON. Ring: array of `"ip:port"` where entry `i` is rank `i`'s listen address. JACCL: device matrix of RDMA device names. | +| `--rank` | `MLX_RANK` | *(required)* | Rank of this process (0 = gRPC server + ring participant, >0 = worker only) | +| `--backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | `ring` (TCP pipeline parallelism) or `jaccl` (RDMA tensor parallelism) | +| `--addr` | `MLX_DISTRIBUTED_ADDR` | `localhost:50051` | gRPC API listen address (rank 0 only, for LocalAI or external access) | +| `--coordinator` | `MLX_JACCL_COORDINATOR` | | JACCL coordinator `ip:port` — rank 0's address for RDMA setup (all ranks must use the same value) | + +### `worker p2p-mlx` + +P2P mode — auto-discovers peers and generates hostfile. 
+ +| Flag | Env | Default | Description | +|------|-----|---------|-------------| +| `--token` | `TOKEN` | *(required)* | P2P network token | +| `--mlx-listen-port` | `MLX_LISTEN_PORT` | `5555` | Port for MLX communication | +| `--mlx-backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | Backend type: `ring` or `jaccl` | + +## Troubleshooting + +- **All ranks download the model independently.** Each node auto-downloads from Hugging Face on first use via `mlx_lm.load()`. On rank 0 (started by LocalAI), models are downloaded to LocalAI's model directory (`HF_HOME` is set automatically). On workers, models go to the default HF cache (`~/.cache/huggingface/hub`) unless you set `HF_HOME` yourself. +- **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. Start workers before loading the model. +- **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs. +- **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory. + +## Acknowledgements + +The MLX distributed auto-parallel sharding implementation is based on [exo](https://github.com/exo-explore/exo).