Mirror of https://github.com/mudler/LocalAI.git (synced 2026-03-31 13:15:51 -04:00)
feat(mlx-distributed): add new MLX-distributed backend (#8801)
Add a new MLX distributed backend with support for both TCP and RDMA for model sharding. This implementation ties into the discovery implementation already in place, and re-uses the same P2P mechanism for the TCP MLX-distributed inferencing. The auto-parallel implementation is inspired by Exo's (which has been added to the acknowledgements for the great work!)

* expose a CLI to facilitate backend starting
* feat: make manual rank0 configurable via model configs
* Add missing features from mlx backend
* Apply suggestion from @mudler

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
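The commit message mentions making the distributed setup configurable via model configs. Based on how `backend.py` reads the `hostfile` and `distributed_backend` keys from the model's options (parsed as `key:value` strings by `parse_options`), a model configuration could look roughly like the sketch below; the model name and hostfile path are placeholders for illustration, not values from this commit:

```yaml
# Hypothetical sketch of a LocalAI model config using the new backend.
# The model name and /path/to/hosts.json are placeholders.
name: llama-distributed
backend: mlx-distributed
parameters:
  model: mlx-community/Meta-Llama-3-8B-Instruct-4bit
options:
  # For the ring backend, the hostfile is a JSON array of "ip:port"
  # strings, one entry per rank.
  - "hostfile:/path/to/hosts.json"
  - "distributed_backend:ring"  # ring = TCP, jaccl = RDMA
```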
Committed by GitHub
Parent: 734b6d391f
Commit: a026277ab9
.github/workflows/backend.yml (vendored, 68 lines changed)
@@ -157,6 +157,19 @@ jobs:
             dockerfile: "./backend/Dockerfile.python"
             context: "./"
             ubuntu-version: '2404'
+          - build-type: ''
+            cuda-major-version: ""
+            cuda-minor-version: ""
+            platforms: 'linux/amd64'
+            tag-latest: 'auto'
+            tag-suffix: '-cpu-mlx-distributed'
+            runs-on: 'ubuntu-latest'
+            base-image: "ubuntu:24.04"
+            skip-drivers: 'true'
+            backend: "mlx-distributed"
+            dockerfile: "./backend/Dockerfile.python"
+            context: "./"
+            ubuntu-version: '2404'
           # CUDA 12 builds
           - build-type: 'cublas'
             cuda-major-version: "12"
@@ -470,6 +483,19 @@ jobs:
             dockerfile: "./backend/Dockerfile.python"
             context: "./"
             ubuntu-version: '2404'
+          - build-type: 'cublas'
+            cuda-major-version: "12"
+            cuda-minor-version: "8"
+            platforms: 'linux/amd64'
+            tag-latest: 'auto'
+            tag-suffix: '-gpu-nvidia-cuda-12-mlx-distributed'
+            runs-on: 'ubuntu-latest'
+            base-image: "ubuntu:24.04"
+            skip-drivers: 'false'
+            backend: "mlx-distributed"
+            dockerfile: "./backend/Dockerfile.python"
+            context: "./"
+            ubuntu-version: '2404'
           - build-type: 'cublas'
             cuda-major-version: "12"
             cuda-minor-version: "8"
@@ -822,6 +848,19 @@ jobs:
             backend: "mlx-audio"
             dockerfile: "./backend/Dockerfile.python"
             context: "./"
+          - build-type: 'l4t'
+            cuda-major-version: "13"
+            cuda-minor-version: "0"
+            platforms: 'linux/arm64'
+            tag-latest: 'auto'
+            tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx-distributed'
+            runs-on: 'ubuntu-24.04-arm'
+            base-image: "ubuntu:24.04"
+            skip-drivers: 'false'
+            ubuntu-version: '2404'
+            backend: "mlx-distributed"
+            dockerfile: "./backend/Dockerfile.python"
+            context: "./"
           - build-type: 'cublas'
             cuda-major-version: "13"
             cuda-minor-version: "0"
@@ -926,6 +965,19 @@ jobs:
             dockerfile: "./backend/Dockerfile.python"
             context: "./"
             ubuntu-version: '2404'
+          - build-type: 'cublas'
+            cuda-major-version: "13"
+            cuda-minor-version: "0"
+            platforms: 'linux/amd64'
+            tag-latest: 'auto'
+            tag-suffix: '-gpu-nvidia-cuda-13-mlx-distributed'
+            runs-on: 'ubuntu-latest'
+            base-image: "ubuntu:24.04"
+            skip-drivers: 'false'
+            backend: "mlx-distributed"
+            dockerfile: "./backend/Dockerfile.python"
+            context: "./"
+            ubuntu-version: '2404'
           - build-type: 'cublas'
             cuda-major-version: "13"
             cuda-minor-version: "0"
@@ -1423,6 +1475,19 @@ jobs:
             dockerfile: "./backend/Dockerfile.python"
             context: "./"
             ubuntu-version: '2204'
+          - build-type: 'l4t'
+            cuda-major-version: "12"
+            cuda-minor-version: "0"
+            platforms: 'linux/arm64'
+            tag-latest: 'auto'
+            tag-suffix: '-nvidia-l4t-mlx-distributed'
+            runs-on: 'ubuntu-24.04-arm'
+            base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
+            skip-drivers: 'true'
+            backend: "mlx-distributed"
+            dockerfile: "./backend/Dockerfile.python"
+            context: "./"
+            ubuntu-version: '2204'
           # SYCL additional backends
           - build-type: 'intel'
             cuda-major-version: ""
@@ -2016,6 +2081,9 @@ jobs:
           - backend: "mlx-audio"
             tag-suffix: "-metal-darwin-arm64-mlx-audio"
             build-type: "mps"
+          - backend: "mlx-distributed"
+            tag-suffix: "-metal-darwin-arm64-mlx-distributed"
+            build-type: "mps"
           - backend: "stablediffusion-ggml"
             tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml"
             build-type: "metal"
Makefile (10 lines changed)
@@ -1,5 +1,5 @@
 # Disable parallel execution for backend builds
-.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral
+.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral
 
 GOCMD=go
 GOTEST=$(GOCMD) test
@@ -451,6 +451,10 @@ backends/mlx-audio:
 	BACKEND=mlx-audio $(MAKE) build-darwin-python-backend
 	./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)"
 
+backends/mlx-distributed:
+	BACKEND=mlx-distributed $(MAKE) build-darwin-python-backend
+	./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-distributed.tar)"
+
 backends/stablediffusion-ggml-darwin:
 	BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend
 	./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)"
@@ -495,6 +499,7 @@ BACKEND_NEMO = nemo|python|.|false|true
 BACKEND_VOXCPM = voxcpm|python|.|false|true
 BACKEND_WHISPERX = whisperx|python|.|false|true
 BACKEND_ACE_STEP = ace-step|python|.|false|true
+BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
 
 # Helper function to build docker image for a backend
 # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
@@ -548,12 +553,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_NEMO)))
 $(eval $(call generate-docker-build-target,$(BACKEND_VOXCPM)))
 $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
 $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
+$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
 
 # Pattern rule for docker-save targets
 docker-save-%: backend-images
 	docker save local-ai-backend:$* -o backend-images/$*.tar
 
-docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral
+docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral docker-build-mlx-distributed
 
 ########################################################
 ### Mock Backend for E2E Tests
@@ -482,6 +482,7 @@ LocalAI couldn't have been built without the help of great software already avai
 - https://github.com/EdVince/Stable-Diffusion-NCNN
 - https://github.com/ggerganov/whisper.cpp
 - https://github.com/rhasspy/piper
+- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation
 
 ## 🤗 Contributors
@@ -259,6 +259,31 @@
       nvidia-l4t: "nvidia-l4t-mlx-audio"
       nvidia-l4t-cuda-12: "nvidia-l4t-mlx-audio"
       nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-audio"
+- &mlx-distributed
+  name: "mlx-distributed"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-distributed"
+  icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
+  urls:
+    - https://github.com/ml-explore/mlx-lm
+  mirrors:
+    - localai/localai-backends:latest-metal-darwin-arm64-mlx-distributed
+  license: MIT
+  description: |
+    Run distributed LLM inference with MLX across multiple Apple Silicon Macs
+  tags:
+    - text-to-text
+    - LLM
+    - MLX
+    - distributed
+  capabilities:
+    default: "cpu-mlx-distributed"
+    nvidia: "cuda12-mlx-distributed"
+    metal: "metal-mlx-distributed"
+    nvidia-cuda-12: "cuda12-mlx-distributed"
+    nvidia-cuda-13: "cuda13-mlx-distributed"
+    nvidia-l4t: "nvidia-l4t-mlx-distributed"
+    nvidia-l4t-cuda-12: "nvidia-l4t-mlx-distributed"
+    nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-distributed"
 - &rerankers
   name: "rerankers"
   alias: "rerankers"
@@ -791,6 +816,11 @@
   uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio"
   mirrors:
     - localai/localai-backends:master-metal-darwin-arm64-mlx-audio
+- !!merge <<: *mlx-distributed
+  name: "mlx-distributed-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:master-metal-darwin-arm64-mlx-distributed
 ## mlx
 - !!merge <<: *mlx
   name: "cpu-mlx"
@@ -944,6 +974,57 @@
   uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio"
   mirrors:
     - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio
+## mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cpu-mlx-distributed"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:latest-cpu-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cpu-mlx-distributed-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:master-cpu-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cuda12-mlx-distributed"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cuda12-mlx-distributed-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:master-gpu-nvidia-cuda-12-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cuda13-mlx-distributed"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cuda13-mlx-distributed-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:master-gpu-nvidia-cuda-13-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "nvidia-l4t-mlx-distributed"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:latest-nvidia-l4t-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "nvidia-l4t-mlx-distributed-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:master-nvidia-l4t-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cuda13-nvidia-l4t-arm64-mlx-distributed"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed
+- !!merge <<: *mlx-distributed
+  name: "cuda13-nvidia-l4t-arm64-mlx-distributed-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed"
+  mirrors:
+    - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed
 - !!merge <<: *kitten-tts
   name: "kitten-tts-development"
   uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"
backend/python/mlx-distributed/Makefile (new file, 23 lines)
@@ -0,0 +1,23 @@
.PHONY: mlx-distributed
mlx-distributed:
	bash install.sh

.PHONY: run
run:
	@echo "Running mlx-distributed..."
	bash run.sh
	@echo "mlx-distributed run."

.PHONY: test
test:
	@echo "Testing mlx-distributed..."
	bash test.sh
	@echo "mlx-distributed tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
backend/python/mlx-distributed/backend.py (new file, 509 lines)
@@ -0,0 +1,509 @@
#!/usr/bin/env python3
"""
MLX Distributed Inference Backend for LocalAI.

Two startup modes:

1. Server mode (started by LocalAI automatically):
   run.sh --addr localhost:50051
   Distributed config comes from LoadModel options or env vars.

2. Worker mode (started by CLI for remote ranks):
   run.sh --worker --hostfile hosts.json --rank 1 --backend ring
   Enters a loop waiting for commands from rank 0.
"""
import asyncio
from concurrent import futures
import argparse
import json
import os
import signal
import sys
import tempfile
from typing import List

import grpc

import backend_pb2
import backend_pb2_grpc

MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
    """Initialize MLX distributed runtime.

    Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank
    binds to its own entry (hostfile[rank]) and connects to neighbors for the
    ring pipeline.

    JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names.
    MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that
    helps all ranks establish RDMA connections.
    """
    import mlx.core as mx

    if backend == "ring":
        os.environ["MLX_HOSTFILE"] = hostfile
        os.environ["MLX_RANK"] = str(rank)
        os.environ["MLX_RING_VERBOSE"] = "1"
        return mx.distributed.init(backend="ring", strict=True)
    elif backend == "jaccl":
        os.environ["MLX_IBV_DEVICES"] = hostfile
        os.environ["MLX_RANK"] = str(rank)
        if coordinator:
            os.environ["MLX_JACCL_COORDINATOR"] = coordinator
        return mx.distributed.init(backend="jaccl", strict=True)
    else:
        raise ValueError(f"Unknown backend: {backend}")


def is_float(s):
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    try:
        int(s)
        return True
    except ValueError:
        return False


def parse_options(options):
    """Parse key:value option strings into a dict."""
    result = {}
    for opt in options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        if is_float(value):
            value = float(value)
        elif is_int(value):
            value = int(value)
        elif value.lower() in ["true", "false"]:
            value = value.lower() == "true"
        result[key] = value
    return result


class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer for distributed MLX inference (runs on rank 0).

    When started by LocalAI (server mode), distributed init happens at
    LoadModel time using config from model options or environment variables.
    """

    def __init__(self):
        self.group = None
        self.dist_backend = None
        self.model = None
        self.tokenizer = None
        self.coordinator = None
        self.options = {}
        self.lru_cache = None
        self.model_key = None
        self.max_kv_size = None

    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        try:
            import mlx.core as mx
            from mlx_lm import load
            from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache

            print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr)

            self.options = parse_options(request.Options)
            print(f"Options: {self.options}", file=sys.stderr)

            # Get distributed config from model options, falling back to env vars.
            # If neither is set, run as single-node (no distributed).
            hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", ""))
            dist_backend = str(self.options.get("distributed_backend",
                                                os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring")))
            # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator).
            # Not in model options — rank 0 is the coordinator, workers get
            # the address via their own --coordinator CLI flag.
            jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "")

            if hostfile:
                from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
                from sharding import pipeline_auto_parallel

                print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr)
                self.dist_backend = dist_backend
                self.group = mlx_distributed_init(
                    rank=0,
                    hostfile=hostfile,
                    backend=dist_backend,
                    coordinator=jaccl_coordinator or None,
                )
                self.coordinator = DistributedCoordinator(self.group)
                self.coordinator.broadcast_command(CMD_LOAD_MODEL)
                self.coordinator.broadcast_model_name(request.Model)
            else:
                print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr)

            # Build tokenizer config from request and options
            tokenizer_config = {}
            if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                tokenizer_config["trust_remote_code"] = True
            # Token overrides from options
            for key in ["eos_token", "pad_token", "bos_token", "unk_token",
                        "sep_token", "cls_token", "mask_token"]:
                if key in self.options:
                    tokenizer_config[key] = self.options[key]

            if tokenizer_config:
                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
                self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
            else:
                self.model, self.tokenizer = load(request.Model)

            if self.group is not None:
                from sharding import pipeline_auto_parallel
                self.model = pipeline_auto_parallel(self.model, self.group)
                print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr)
            else:
                # Single-node: set up prompt cache for efficient generation
                from mlx_cache import ThreadSafeLRUPromptCache
                max_cache_entries = self.options.get("max_cache_entries", 10)
                self.max_kv_size = self.options.get("max_kv_size", None)
                self.model_key = request.Model
                self.lru_cache = ThreadSafeLRUPromptCache(
                    max_size=max_cache_entries,
                    can_trim_fn=can_trim_prompt_cache,
                    trim_fn=trim_prompt_cache,
                )
                print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)

        except Exception as err:
            print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading model: {err}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    async def Predict(self, request, context):
        prompt_cache = None
        cache_key = None

        try:
            import mlx.core as mx
            from mlx_lm import stream_generate
            from mlx_lm.sample_utils import make_sampler

            prompt_text = self._prepare_prompt(request)
            tokens = self._get_tokens_from_prompt(prompt_text)

            if self.coordinator:
                from coordinator import CMD_GENERATE
                self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                self.coordinator.broadcast_tokens(tokens)

            max_tokens, sampler_params = self._build_generation_params(request)

            if self.coordinator:
                gen_params = self.coordinator.broadcast_generation_params(
                    max_tokens=max_tokens,
                    temperature=sampler_params.get('temp', 0.6),
                    top_p=sampler_params.get('top_p', 1.0),
                )
                max_tokens = gen_params["max_tokens"]

            sampler = make_sampler(**sampler_params)

            # Use prompt cache in single-node mode
            gen_kwargs = {}
            if self.lru_cache is not None:
                from mlx_lm.models.cache import make_prompt_cache
                cache_key = list(tokens)
                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                    self.model_key, cache_key
                )
                if prompt_cache is None:
                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                    remaining_tokens = cache_key
                gen_kwargs['prompt_cache'] = prompt_cache
                tokens = remaining_tokens if remaining_tokens else cache_key

            generated = []
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=tokens,
                max_tokens=max_tokens,
                sampler=sampler,
                **gen_kwargs,
            ):
                generated.append(response.text)
                if cache_key is not None:
                    cache_key.append(response.token)

            if self.lru_cache is not None and cache_key is not None:
                self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)

            return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))

        except Exception as e:
            print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    async def PredictStream(self, request, context):
        prompt_cache = None
        cache_key = None

        try:
            import mlx.core as mx
            from mlx_lm import stream_generate
            from mlx_lm.sample_utils import make_sampler

            prompt_text = self._prepare_prompt(request)
            tokens = self._get_tokens_from_prompt(prompt_text)

            if self.coordinator:
                from coordinator import CMD_GENERATE
                self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                self.coordinator.broadcast_tokens(tokens)

            max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)

            if self.coordinator:
                gen_params = self.coordinator.broadcast_generation_params(
                    max_tokens=max_tokens,
                    temperature=sampler_params.get('temp', 0.6),
                    top_p=sampler_params.get('top_p', 1.0),
                )
                max_tokens = gen_params["max_tokens"]

            sampler = make_sampler(**sampler_params)

            # Use prompt cache in single-node mode
            gen_kwargs = {}
            if self.lru_cache is not None:
                from mlx_lm.models.cache import make_prompt_cache
                cache_key = list(tokens)
                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                    self.model_key, cache_key
                )
                if prompt_cache is None:
                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                    remaining_tokens = cache_key
                gen_kwargs['prompt_cache'] = prompt_cache
                tokens = remaining_tokens if remaining_tokens else cache_key

            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=tokens,
                max_tokens=max_tokens,
                sampler=sampler,
                **gen_kwargs,
            ):
                if cache_key is not None:
                    cache_key.append(response.token)
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))

        except Exception as e:
            print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))

        finally:
            if self.lru_cache is not None and prompt_cache is not None and cache_key is not None:
                try:
                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
                except Exception as e:
                    print(f"Error inserting cache: {e}", file=sys.stderr)

    def Embedding(self, request, context):
        print("Embeddings not supported in MLX distributed backend", file=sys.stderr)
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embeddings are not supported in the MLX distributed backend.")
        return backend_pb2.EmbeddingResult()

    def _prepare_prompt(self, request):
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return request.Prompt

    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
        tokens = self.tokenizer.encode(prompt_text)
        if hasattr(tokens, 'tolist'):
            return tokens.tolist()
        return list(tokens)

    def _build_generation_params(self, request, default_max_tokens=200):
        import mlx.core as mx

        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0

        sampler_params = {
            'temp': temp,
            'top_p': top_p,
            'min_p': getattr(request, 'MinP', 0.0),
            'top_k': getattr(request, 'TopK', 0),
            'xtc_threshold': 0.0,
            'xtc_probability': 0.0,
        }

        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        if hasattr(self, 'options'):
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']
            option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',
                'top_p': 'top_p',
                'min_p': 'min_p',
                'top_k': 'top_k',
                'xtc_threshold': 'xtc_threshold',
                'xtc_probability': 'xtc_probability',
            }
            for opt_key, param_key in option_mapping.items():
                if opt_key in self.options:
                    sampler_params[param_key] = self.options[opt_key]
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        # XTC special tokens
        xtc_special_tokens = []
        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            xtc_special_tokens = [self.tokenizer.eos_token_id]
        try:
            newline_tokens = self.tokenizer.encode("\n")
            xtc_special_tokens.extend(newline_tokens)
        except:
            pass
        sampler_params['xtc_special_tokens'] = xtc_special_tokens

        return max_tokens, sampler_params


def run_worker(group):
    """Worker loop for ranks > 0. Waits for commands from rank 0."""
    from mlx_lm import load, stream_generate
    from mlx_lm.sample_utils import make_sampler
    from coordinator import DistributedCoordinator, CMD_LOAD_MODEL, CMD_GENERATE, CMD_SHUTDOWN
    from sharding import pipeline_auto_parallel
    import mlx.core as mx

    coordinator = DistributedCoordinator(group)
    model = None
    tokenizer = None

    print(f"[Rank {group.rank()}] Worker started, waiting for commands...", file=sys.stderr)

    while True:
        cmd, payload_size = coordinator.wait_for_command()

        if cmd == CMD_LOAD_MODEL:
            model_name = coordinator.broadcast_model_name()
            print(f"[Rank {group.rank()}] Loading model: {model_name}", file=sys.stderr)
            model, tokenizer = load(model_name)
            model = pipeline_auto_parallel(model, group)
            print(f"[Rank {group.rank()}] Model loaded and sharded", file=sys.stderr)

        elif cmd == CMD_GENERATE:
            if model is None:
                print(f"[Rank {group.rank()}] No model loaded, skipping generate", file=sys.stderr)
                continue

            token_count = coordinator.broadcast_token_count(payload_size)
            tokens_array = coordinator.broadcast_tokens([0] * token_count)
            tokens = tokens_array.tolist()

            gen_params = coordinator.broadcast_generation_params()

            sampler = make_sampler(
                temp=gen_params["temperature"],
                top_p=gen_params["top_p"],
            )

            for _ in stream_generate(
                model, tokenizer,
                prompt=tokens,
                max_tokens=gen_params["max_tokens"],
                sampler=sampler,
            ):
                pass

        elif cmd == CMD_SHUTDOWN:
            print(f"[Rank {group.rank()}] Shutting down", file=sys.stderr)
            break


async def serve(address):
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))

    await server.start()
    print(f"[Rank 0] gRPC server listening on {address}", file=sys.stderr)
    await server.wait_for_termination()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MLX Distributed Backend")
    parser.add_argument("--addr", default="localhost:50051",
                        help="gRPC listen address (used by LocalAI to send requests)")
    parser.add_argument("--worker", action="store_true",
                        help="Run in worker mode (for remote ranks started by CLI)")
    parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"],
                        help="ring = TCP pipeline parallelism, jaccl = RDMA tensor parallelism")
    parser.add_argument("--hostfile", default=None,
                        help="Path to hostfile JSON (required for --worker mode)")
    parser.add_argument("--rank", type=int, default=0,
                        help="Rank of this process (0 = server, >0 = worker)")
    parser.add_argument("--coordinator", default=None,
                        help="JACCL coordinator ip:port (jaccl backend only)")
    args = parser.parse_args()

    if args.worker:
        if not args.hostfile:
            print("Error: --hostfile is required in worker mode", file=sys.stderr)
            sys.exit(1)
        group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator)
        run_worker(group)
    else:
        # Server mode: started by LocalAI with just --addr.
        # Distributed init deferred to LoadModel (reads config from model options/env vars).
        asyncio.run(serve(args.addr))
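Per the CLI help for `--hostfile` (and the `MLXDistributed` worker flags later in this commit), the ring backend's hostfile is a JSON array of `ip:port` strings where entry i is rank i's listen address; the CLI's MLX tunnel callback produces exactly this shape by marshaling the tunnel address list. A hypothetical two-rank hostfile (addresses are illustrative, not from the source):

```json
["192.168.1.10:36001", "192.168.1.11:36001"]
```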
104
backend/python/mlx-distributed/coordinator.py
Normal file
@@ -0,0 +1,104 @@
"""
Distributed coordination using MLX distributed primitives.

Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather.
Worker ranks wait in a loop for commands from rank 0.
"""
import json
import struct

import mlx.core as mx


CMD_IDLE = 0
CMD_GENERATE = 1
CMD_LOAD_MODEL = 2
CMD_SHUTDOWN = -1


class DistributedCoordinator:
    def __init__(self, group):
        self.group = group
        self.rank = group.rank()
        self.world_size = group.size()

    def broadcast_command(self, cmd, payload_size=0):
        """Rank 0 broadcasts a command to all ranks.

        Uses all_sum with only rank 0 providing non-zero values so every
        rank receives the same command array.
        """
        if self.rank == 0:
            cmd_array = mx.array([cmd, payload_size], dtype=mx.int32)
        else:
            cmd_array = mx.zeros((2,), dtype=mx.int32)
        result = mx.distributed.all_sum(cmd_array, group=self.group)
        mx.eval(result)
        return int(result[0].item()), int(result[1].item())

    def broadcast_tokens(self, tokens):
        """Broadcast input token ids from rank 0 to all ranks.

        Rank 0 provides the real token array; other ranks provide zeros of the
        same shape. ``all_sum`` ensures every rank ends up with identical data.
        """
        if self.rank == 0:
            token_array = mx.array(tokens, dtype=mx.int32)
        else:
            token_array = mx.zeros((len(tokens),), dtype=mx.int32)
        result = mx.distributed.all_sum(token_array, group=self.group)
        mx.eval(result)
        return result

    def broadcast_token_count(self, count):
        """Broadcast the number of tokens so workers can prepare a buffer."""
        if self.rank == 0:
            count_array = mx.array([count], dtype=mx.int32)
        else:
            count_array = mx.zeros((1,), dtype=mx.int32)
        result = mx.distributed.all_sum(count_array, group=self.group)
        mx.eval(result)
        return int(result[0].item())

    def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0):
        """Broadcast generation parameters from rank 0."""
        if self.rank == 0:
            params = mx.array([max_tokens, temperature, top_p], dtype=mx.float32)
        else:
            params = mx.zeros((3,), dtype=mx.float32)
        result = mx.distributed.all_sum(params, group=self.group)
        mx.eval(result)
        return {
            "max_tokens": int(result[0].item()),
            "temperature": float(result[1].item()),
            "top_p": float(result[2].item()),
        }

    def wait_for_command(self):
        """Worker ranks block here until rank 0 broadcasts a command."""
        return self.broadcast_command(CMD_IDLE, 0)

    def broadcast_model_name(self, model_name=""):
        """Broadcast model name string from rank 0 to all ranks.

        Encodes the model name as int32 codepoints so it can travel via
        all_sum.
        """
        if self.rank == 0:
            encoded = [ord(c) for c in model_name]
            # First broadcast the length
            length = self.broadcast_token_count(len(encoded))
            if length > 0:
                name_array = mx.array(encoded, dtype=mx.int32)
                result = mx.distributed.all_sum(name_array, group=self.group)
                mx.eval(result)
                return model_name
            return ""
        else:
            length = self.broadcast_token_count(0)
            if length > 0:
                name_array = mx.zeros((length,), dtype=mx.int32)
                result = mx.distributed.all_sum(name_array, group=self.group)
                mx.eval(result)
                return "".join(chr(int(c.item())) for c in result)
            return ""
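The all_sum broadcast trick used throughout the coordinator can be illustrated without MLX: when only rank 0 contributes non-zero values and every other rank contributes zeros, an elementwise sum across ranks leaves every rank holding rank 0's data. A minimal pure-Python sketch of that idea (the names here are illustrative, not part of the backend):

```python
# Simulate an elementwise all_sum collective: every rank receives the
# elementwise sum of all ranks' contributions.
def all_sum(contributions):
    return [sum(vals) for vals in zip(*contributions)]

world_size = 3
tokens = [101, 42, 7]  # only rank 0 knows the real data

# Rank 0 contributes the real values; all other ranks contribute zeros.
contributions = [tokens] + [[0] * len(tokens) for _ in range(world_size - 1)]
result = all_sum(contributions)
assert result == tokens  # every rank now holds identical data
```

The same pattern carries commands, token counts, generation parameters, and even the model name (as int32 codepoints).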
15
backend/python/mlx-distributed/install.sh
Normal file
@@ -0,0 +1,15 @@
#!/bin/bash
set -e

USE_PIP=true
PYTHON_VERSION=""

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

installRequirements
266
backend/python/mlx-distributed/mlx_cache.py
Normal file
@@ -0,0 +1,266 @@
"""
Thread-safe LRU prompt cache for MLX-based backends.

Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
with thread-safety additions for LocalAI's gRPC backend.

Usage:
    from mlx_cache import ThreadSafeLRUPromptCache

    # In LoadModel:
    self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)

    # In Predict/PredictStream:
    prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
    # ... generate ...
    self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
"""
import copy
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class CacheEntry:
    """A cache entry with reference counting."""
    prompt_cache: List[Any]
    count: int


@dataclass
class SearchResult:
    """Result of searching the cache trie."""
    model: Any
    exact: Optional[List[int]]
    shorter: Optional[List[int]]
    longer: Optional[List[int]]
    common_prefix: int


class ThreadSafeLRUPromptCache:
    """
    Thread-safe LRU cache with prefix matching for prompt KV caches.

    This cache stores KV caches keyed by token sequences and supports:
    - Exact match: Return the cache for the exact token sequence
    - Shorter prefix match: Return a cache for a prefix of the tokens
    - Longer prefix match: If a longer sequence is cached and can be trimmed
    - LRU eviction: When max_size is exceeded, evict least recently used

    Thread safety is provided via a threading.Lock that protects all
    cache operations.

    Args:
        max_size: Maximum number of cache entries (default: 10)
        can_trim_fn: Optional function to check if a cache can be trimmed
        trim_fn: Optional function to trim a cache
    """

    def __init__(
        self,
        max_size: int = 10,
        can_trim_fn: Optional[Any] = None,
        trim_fn: Optional[Any] = None,
    ):
        self.max_size = max_size
        self._cache = {}
        self._lru = deque()
        self._lock = threading.Lock()

        # Optional trim functions (for longer prefix reuse)
        self._can_trim_fn = can_trim_fn
        self._trim_fn = trim_fn

    def _search(self, model, tokens: List[int]) -> SearchResult:
        """
        Search the cache for a prompt cache. Return exact or close match.

        The cache is organized as a trie where each node is keyed by a token.
        This allows efficient prefix matching.
        """
        if model not in self._cache:
            return SearchResult(model, None, None, None, 0)

        current = self._cache[model]
        last_cache_index = -1
        index = 0

        # Traverse the trie following the token sequence
        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "cache" in current:
                last_cache_index = index
            index += 1

        # Exact match - no need to search for longer or shorter caches
        if last_cache_index == len(tokens) - 1:
            return SearchResult(model, tuple(tokens), None, None, 0)

        # Find the shorter cache (a prefix that has a cache)
        # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
        # Single-token prefixes are not matched, which allows longer cached
        # sequences to be preferred for trimming. This is acceptable because
        # real prompts with chat templates are always many tokens.
        shorter = None
        if last_cache_index > 0:
            shorter = tuple(tokens[: last_cache_index + 1])

        # Check for caches that are longer than our token sequence
        longer = None
        common_prefix = index
        if index > 0 and last_cache_index <= 0:
            best = None
            stack = [(current, [])]
            while stack:
                current, extra = stack.pop()
                if "cache" in current:
                    if best is None or len(extra) < len(best):
                        best = extra
                else:
                    for tok in current:
                        stack.append((current[tok], extra + [tok]))
            if best is not None:
                longer = tuple(tokens[:index] + best)

        return SearchResult(model, None, shorter, longer, common_prefix)

    def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """Get a cache entry by traversing the trie."""
        current = self._cache[model]
        for tok in tokens:
            current = current[tok]
        return current["cache"]

    def _delete(self, model, tokens: Tuple[int, ...]) -> None:
        """Delete a cache entry and clean up empty trie nodes."""
        path = [self._cache[model]]
        for tok in tokens:
            path.append(path[-1][tok])
        del path[-1]["cache"]

        # Clean up empty nodes bottom-up
        for i in reversed(range(len(tokens))):
            d_prev, d, t = path[i], path[i + 1], tokens[i]
            if len(d) > 0:
                break
            del d_prev[t]

    def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """
        Extract a cache entry for exclusive use.

        If the entry has count > 1, deep copy and decrement.
        If count == 1, remove from cache entirely.
        """
        cache_entry = self._get(model, tokens)
        if cache_entry.count == 1:
            self._delete(model, tokens)
            self._lru.remove((model, tokens))
            return cache_entry

        cache_entry.count -= 1
        return CacheEntry(
            copy.deepcopy(cache_entry.prompt_cache),
            1,
        )

    def fetch_nearest_cache(
        self, model, tokens: List[int]
    ) -> Tuple[Optional[List[Any]], List[int]]:
        """
        Fetch the nearest cache for the given token sequence.

        Thread-safe. Returns (cache, remaining_tokens) where:
        - cache: The KV cache to use (or None if no cache found)
        - remaining_tokens: Tokens that still need to be processed

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence for the prompt

        Returns:
            Tuple of (prompt_cache, remaining_tokens)
        """
        with self._lock:
            tokens_tuple = tuple(tokens)
            result = self._search(model, tokens)

            # Exact match - extract and return
            if result.exact is not None:
                cache_entry = self._extract(result.model, result.exact)
                return cache_entry.prompt_cache, []

            # Shorter prefix match - extract and return remaining
            if result.shorter is not None:
                cache_entry = self._extract(result.model, result.shorter)
                prefix_len = len(result.shorter)
                return cache_entry.prompt_cache, list(tokens[prefix_len:])

            # Longer prefix match - try to trim if possible
            if result.longer is not None and self._can_trim_fn is not None:
                cache_entry = self._get(result.model, result.longer)
                if self._can_trim_fn(cache_entry.prompt_cache):
                    # Deep copy and trim
                    trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
                    prefix = min(len(tokens) - 1, result.common_prefix)
                    num_to_trim = len(result.longer) - prefix
                    if self._trim_fn is not None:
                        self._trim_fn(trimmed_cache, num_to_trim)
                    return trimmed_cache, list(tokens[prefix:])

            # No match found
            return None, list(tokens)

    def insert_cache(
        self, model, tokens: List[int], prompt_cache: List[Any]
    ) -> None:
        """
        Insert a cache entry after generation completes.

        Thread-safe. Handles LRU eviction if max_size is exceeded.

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence (prompt + generated)
            prompt_cache: The KV cache to store
        """
        with self._lock:
            tokens_tuple = tuple(tokens)

            if model not in self._cache:
                self._cache[model] = {}
            current = self._cache[model]

            # Build trie path
            for tok in tokens_tuple:
                if tok not in current:
                    current[tok] = {}
                current = current[tok]

            # Update or create entry
            if "cache" in current:
                current["cache"].count += 1
                self._lru.remove((model, tokens_tuple))
            else:
                current["cache"] = CacheEntry(prompt_cache, 1)

            # Update LRU order
            self._lru.append((model, tokens_tuple))

            # Evict if over capacity
            if len(self._lru) > self.max_size:
                evict_model, evict_tokens = self._lru.popleft()
                self._delete(evict_model, evict_tokens)

    def clear(self) -> None:
        """Clear all cache entries. Thread-safe."""
        with self._lock:
            self._cache.clear()
            self._lru.clear()

    def __len__(self) -> int:
        """Return the number of cache entries. Thread-safe."""
        with self._lock:
            return len(self._lru)
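The trie-based prefix matching that the cache builds on can be shown in a stripped-down, self-contained form (this is an illustration of the technique, not the backend's class): insert a cached prompt, then look up a longer prompt and recover the cached prefix plus the tokens that still need processing.

```python
# Minimal token-trie prefix matcher, illustrative only.
def insert(trie, tokens, value):
    node = trie
    for tok in tokens:
        node = node.setdefault(tok, {})
    node["cache"] = value  # mark this node as holding a cached KV state

def longest_cached_prefix(trie, tokens):
    node, best = trie, -1
    for i, tok in enumerate(tokens):
        if tok not in node:
            break
        node = node[tok]
        if "cache" in node:
            best = i
    return best  # index of the last token covered by a cache, or -1

trie = {}
insert(trie, [1, 2, 3], "kv-cache-for-[1,2,3]")

prompt = [1, 2, 3, 4, 5]
idx = longest_cached_prefix(trie, prompt)
cached, remaining = prompt[: idx + 1], prompt[idx + 1 :]
assert cached == [1, 2, 3] and remaining == [4, 5]
```

Only the `remaining` tokens need a fresh forward pass; the cached prefix's KV state is reused, which is the whole point of the prompt cache.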
2
backend/python/mlx-distributed/requirements-cpu.txt
Normal file
@@ -0,0 +1,2 @@
mlx-lm
mlx[cpu]
2
backend/python/mlx-distributed/requirements-cublas12.txt
Normal file
@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda12]
2
backend/python/mlx-distributed/requirements-cublas13.txt
Normal file
@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda13]
2
backend/python/mlx-distributed/requirements-l4t12.txt
Normal file
@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda12]
2
backend/python/mlx-distributed/requirements-l4t13.txt
Normal file
@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda13]
1
backend/python/mlx-distributed/requirements-mps.txt
Normal file
@@ -0,0 +1 @@
mlx-lm
4
backend/python/mlx-distributed/requirements.txt
Normal file
@@ -0,0 +1,4 @@
grpcio==1.71.0
protobuf
certifi
setuptools
11
backend/python/mlx-distributed/run.sh
Normal file
@@ -0,0 +1,11 @@
#!/bin/bash

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@
136
backend/python/mlx-distributed/sharding.py
Normal file
@@ -0,0 +1,136 @@
"""
Auto-parallelism for MLX distributed inference.

Provides pipeline parallelism (Ring backend) by wrapping model layers with
distributed send/recv operations. Ported from exo's auto_parallel.py with
simplifications for LocalAI's use case.
"""
import mlx.core as mx
import mlx.nn as nn


class PipelineFirstLayer(nn.Module):
    """Wraps the first layer on each rank to receive from the previous rank."""

    def __init__(self, original_layer, rank, group):
        super().__init__()
        dict.__setitem__(self, "_original_layer", original_layer)
        self.rank = rank
        self.group = group

    @property
    def original_layer(self):
        return self["_original_layer"]

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self["_original_layer"], name)

    def __call__(self, x, *args, **kwargs):
        if self.rank != 0:
            mx.eval(x)
            x = mx.distributed.recv_like(x, self.rank - 1, group=self.group)
            mx.eval(x)
        return self.original_layer(x, *args, **kwargs)


class PipelineLastLayer(nn.Module):
    """Wraps the last layer on each rank to send to the next rank."""

    def __init__(self, original_layer, rank, world_size, group):
        super().__init__()
        dict.__setitem__(self, "_original_layer", original_layer)
        self.rank = rank
        self.world_size = world_size
        self.group = group

    @property
    def original_layer(self):
        return self["_original_layer"]

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self["_original_layer"], name)

    def __call__(self, x, *args, **kwargs):
        output = self.original_layer(x, *args, **kwargs)
        mx.eval(output)
        if self.rank != self.world_size - 1:
            output = mx.distributed.send(
                output, (self.rank + 1) % self.world_size, group=self.group
            )
            mx.eval(output)
        # Gather output from all ranks so every rank has the final result
        output = mx.distributed.all_gather(output, group=self.group)[
            -output.shape[0] :
        ]
        mx.eval(output)
        return output


def get_inner_model(model):
    """Get the inner model (model.model or model.transformer)."""
    for attr in ("model", "transformer"):
        inner = getattr(model, attr, None)
        if isinstance(inner, nn.Module):
            # Some models have model.model (e.g. language_model.model)
            inner_inner = getattr(inner, "model", None)
            if isinstance(inner_inner, nn.Module):
                return inner_inner
            return inner
    raise ValueError("Model must have a 'model' or 'transformer' attribute")


def get_layers(inner_model):
    """Get the list of transformer layers."""
    for attr in ("layers", "h"):
        layers = getattr(inner_model, attr, None)
        if layers is not None:
            return layers
    raise ValueError("Model must have a 'layers' or 'h' attribute")


def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None):
    """Apply pipeline parallelism to a model.

    Each rank only keeps its slice of layers. The first layer receives from
    the previous rank, and the last layer sends to the next rank.

    Args:
        model: The MLX model (must have model.layers or similar)
        group: The distributed group
        start_layer: First layer index for this rank (auto-computed if None)
        end_layer: Last layer index (exclusive) for this rank (auto-computed if None)
    """
    rank = group.rank()
    world_size = group.size()

    inner = get_inner_model(model)
    layers = list(get_layers(inner))
    total_layers = len(layers)

    if start_layer is None or end_layer is None:
        layers_per_rank = total_layers // world_size
        remainder = total_layers % world_size
        start_layer = rank * layers_per_rank + min(rank, remainder)
        end_layer = start_layer + layers_per_rank + (1 if rank < remainder else 0)

    layers = layers[start_layer:end_layer]
    for layer in layers:
        mx.eval(layer)

    # Wrap first and last layers
    layers[0] = PipelineFirstLayer(layers[0], rank, group=group)
    layers[-1] = PipelineLastLayer(layers[-1], rank, world_size, group=group)

    # Replace layers on the inner model
    if hasattr(inner, "layers"):
        inner.layers = layers
    elif hasattr(inner, "h"):
        inner.h = layers

    return model
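The start/end computation in `pipeline_auto_parallel` splits `total_layers` across ranks as evenly as possible, giving the first `remainder` ranks one extra layer. A quick standalone check of that arithmetic (same formula, extracted for illustration):

```python
def layer_slice(rank, world_size, total_layers):
    # Same arithmetic as pipeline_auto_parallel: even split, with the
    # first (total_layers % world_size) ranks taking one extra layer.
    layers_per_rank = total_layers // world_size
    remainder = total_layers % world_size
    start = rank * layers_per_rank + min(rank, remainder)
    end = start + layers_per_rank + (1 if rank < remainder else 0)
    return start, end

# 10 layers over 3 ranks: slices are contiguous and cover all layers.
slices = [layer_slice(r, 3, 10) for r in range(3)]
assert slices == [(0, 4), (4, 7), (7, 10)]
```

Each slice's end equals the next slice's start, so no layer is dropped or duplicated across the pipeline.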
87
backend/python/mlx-distributed/test.py
Normal file
@@ -0,0 +1,87 @@
import unittest
import subprocess
import time

import grpc
import backend_pb2
import backend_pb2_grpc


class TestBackendServicer(unittest.TestCase):
    def setUp(self):
        self.service = subprocess.Popen(
            ["python", "backend.py", "--addr", "localhost:50051"]
        )
        time.sleep(10)

    def tearDown(self) -> None:
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_text(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("text service failed")
        finally:
            self.tearDown()

    def test_sampling_params(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)

                req = backend_pb2.PredictOptions(
                    Prompt="The capital of France is",
                    TopP=0.8,
                    Tokens=50,
                    Temperature=0.7,
                    TopK=40,
                    MinP=0.05,
                    Seed=42,
                )
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("sampling params service failed")
        finally:
            self.tearDown()
12
backend/python/mlx-distributed/test.sh
Normal file
@@ -0,0 +1,12 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests
@@ -84,19 +84,37 @@ func (a *Application) StartP2P() error {
 		n = node
 	}

-	// Attach a ServiceDiscoverer to the p2p node
+	// Attach a ServiceDiscoverer to the p2p node for llama.cpp workers
 	xlog.Info("Starting P2P server discovery...")
-	if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node schema.NodeData) {
+	if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID), func(serviceID string, node schema.NodeData) {
 		var tunnelAddresses []string
-		for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) {
+		for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID)) {
 			if v.IsOnline() {
 				tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
 			} else {
 				xlog.Info("Node is offline", "node", v.ID)
 			}
 		}
-		if a.applicationConfig.TunnelCallback != nil {
-			a.applicationConfig.TunnelCallback(tunnelAddresses)
+		if a.applicationConfig.LlamaCPPTunnelCallback != nil {
+			a.applicationConfig.LlamaCPPTunnelCallback(tunnelAddresses)
 		}
 	}, true); err != nil {
 		return err
 	}

+	// Attach a ServiceDiscoverer for MLX distributed workers
+	xlog.Info("Starting MLX P2P worker discovery...")
+	if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.MLXWorkerID), func(serviceID string, node schema.NodeData) {
+		var tunnelAddresses []string
+		for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.MLXWorkerID)) {
+			if v.IsOnline() {
+				tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
+			} else {
+				xlog.Info("MLX node is offline", "node", v.ID)
+			}
+		}
+		if a.applicationConfig.MLXTunnelCallback != nil {
+			a.applicationConfig.MLXTunnelCallback(tunnelAddresses)
+		}
+	}, true); err != nil {
+		return err
@@ -2,8 +2,10 @@ package cli

 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"os"
+	"path/filepath"
 	"strings"
 	"time"

@@ -171,12 +173,18 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 		config.WithMachineTag(r.MachineTag),
 		config.WithAPIAddress(r.Address),
 		config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
-		config.WithTunnelCallback(func(tunnels []string) {
+		config.WithLlamaCPPTunnelCallback(func(tunnels []string) {
 			tunnelEnvVar := strings.Join(tunnels, ",")
 			// TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable
 			os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
 			xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar)
 		}),
+		config.WithMLXTunnelCallback(func(tunnels []string) {
+			hostfile := filepath.Join(os.TempDir(), "localai_mlx_hostfile.json")
+			data, _ := json.Marshal(tunnels)
+			os.WriteFile(hostfile, data, 0644)
+			os.Setenv("MLX_DISTRIBUTED_HOSTFILE", hostfile)
+			xlog.Debug("setting MLX_DISTRIBUTED_HOSTFILE", "value", hostfile, "tunnels", tunnels)
+		}),
 	}

 	if r.DisableMetricsEndpoint {
@@ -8,6 +8,8 @@ type WorkerFlags struct {
 }

 type Worker struct {
-	P2P      P2P      `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
-	LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
+	P2P            P2P            `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
+	P2PMLX         P2PMLX         `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"`
+	LLamaCPP       LLamaCPP       `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
+	MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"`
 }
65
core/cli/worker/worker_mlx_common.go
Normal file
@@ -0,0 +1,65 @@
package worker

import (
	"context"
	"encoding/json"
	"errors"
	"os/exec"
	"path/filepath"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/xlog"
)

const mlxDistributedGalleryName = "mlx-distributed"

// findMLXDistributedBackendPath finds or installs the mlx-distributed backend
// and returns the directory containing run.sh.
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) {
	backends, err := gallery.ListSystemBackends(systemState)
	if err != nil {
		return "", err
	}

	backend, ok := backends.Get(mlxDistributedGalleryName)
	if !ok {
		ml := model.NewModelLoader(systemState)
		var gals []config.Gallery
		if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
			xlog.Error("failed loading galleries", "error", err)
			return "", err
		}
		if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true); err != nil {
			xlog.Error("mlx-distributed backend not found, failed to install it", "error", err)
			return "", err
		}
		// Re-fetch after install
		backends, err = gallery.ListSystemBackends(systemState)
		if err != nil {
			return "", err
		}
		backend, ok = backends.Get(mlxDistributedGalleryName)
		if !ok {
			return "", errors.New("mlx-distributed backend not found after install")
		}
	}

	backendPath := filepath.Dir(backend.RunFile)
	if backendPath == "" {
		return "", errors.New("mlx-distributed backend not found, install it first")
	}
	return backendPath, nil
}

// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend.
// backendPath is the directory containing run.sh (empty string to fall back to
// running backend.py directly via python3).
func buildMLXCommand(backendPath string, args ...string) *exec.Cmd {
	if backendPath != "" {
		return exec.Command(filepath.Join(backendPath, "run.sh"), args...)
	}
	return exec.Command("python3", append([]string{"backend.py"}, args...)...)
}
62  core/cli/worker/worker_mlx_distributed.go  Normal file
@@ -0,0 +1,62 @@
package worker

import (
	"fmt"
	"os"
	"syscall"

	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/xlog"
)

type MLXDistributed struct {
	WorkerFlags `embed:""`
	Hostfile    string `env:"MLX_DISTRIBUTED_HOSTFILE" required:"" help:"Path to hostfile JSON. Ring: array of 'ip:port' where entry i is rank i's listen address. JACCL: 2D matrix of RDMA device names."`
	Rank        int    `env:"MLX_RANK" required:"" help:"Rank of this process (0 = gRPC server + ring participant, >0 = worker only)"`
	Backend     string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend: 'ring' (TCP pipeline parallelism) or 'jaccl' (RDMA tensor parallelism)"`
	Addr        string `env:"MLX_DISTRIBUTED_ADDR" default:"localhost:50051" help:"gRPC API listen address for LocalAI (rank 0 only, separate from ring communication)"`
	Coordinator string `env:"MLX_JACCL_COORDINATOR" default:"" help:"JACCL coordinator ip:port — rank 0's address where it accepts RDMA setup connections (all ranks must use the same value)"`
}

func (r *MLXDistributed) Run(ctx *cliContext.Context) error {
	systemState, err := system.GetSystemState(
		system.WithBackendPath(r.BackendsPath),
		system.WithBackendSystemPath(r.BackendsSystemPath),
	)
	if err != nil {
		return err
	}

	backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
	if err != nil {
		return fmt.Errorf("cannot find mlx-distributed backend: %w", err)
	}

	args := []string{
		"--backend", r.Backend,
		"--hostfile", r.Hostfile,
		"--rank", fmt.Sprint(r.Rank),
	}

	if r.Rank == 0 {
		args = append(args, "--addr", r.Addr)
	} else {
		args = append(args, "--worker")
	}

	if r.Backend == "jaccl" && r.Coordinator != "" {
		args = append(args, "--coordinator", r.Coordinator)
	}

	cmd := buildMLXCommand(backendPath, args...)
	runSh := cmd.Path

	xlog.Info("Starting mlx-distributed", "rank", r.Rank, "backend", r.Backend, "hostfile", r.Hostfile)

	return syscall.Exec(
		runSh,
		append([]string{runSh}, args...),
		os.Environ(),
	)
}
@@ -62,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
		p = r.RunnerPort
	}

-	_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
+	_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID))
	if err != nil {
		return err
	}
@@ -104,7 +104,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
		}
	}()

-	_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
+	_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID))
	if err != nil {
		return err
	}

97  core/cli/worker/worker_p2p_mlx.go  Normal file
@@ -0,0 +1,97 @@
package worker

import (
	"context"
	"fmt"
	"os"
	"time"

	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/core/p2p"
	"github.com/mudler/LocalAI/pkg/signals"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/xlog"
	"github.com/phayes/freeport"
)

type P2PMLX struct {
	WorkerFlags        `embed:""`
	Token              string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"`
	Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode" group:"p2p"`
	MLXListenPort      string `env:"MLX_LISTEN_PORT" default:"5555" help:"Port for MLX distributed communication"`
	MLXBackend         string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend (ring or jaccl)"`
}

func (r *P2PMLX) Run(ctx *cliContext.Context) error {
	if r.Token == "" {
		return fmt.Errorf("token is required")
	}

	systemState, err := system.GetSystemState(
		system.WithBackendPath(r.BackendsPath),
		system.WithBackendSystemPath(r.BackendsSystemPath),
	)
	if err != nil {
		return err
	}

	port, err := freeport.GetFreePort()
	if err != nil {
		return err
	}
	if r.MLXListenPort != "" {
		fmt.Sscanf(r.MLXListenPort, "%d", &port)
	}

	address := "127.0.0.1"

	c, cancel := context.WithCancel(context.Background())
	defer cancel()

	backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
	if err != nil {
		xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err)
	}

	go func() {
		for {
			hostfile := os.Getenv("MLX_DISTRIBUTED_HOSTFILE")
			if hostfile == "" {
				xlog.Info("Waiting for MLX_DISTRIBUTED_HOSTFILE to be set by P2P discovery...")
				time.Sleep(2 * time.Second)
				continue
			}

			xlog.Info("Starting mlx-distributed worker", "address", address, "port", port, "hostfile", hostfile)

			cmd := buildMLXCommand(backendPath,
				"--worker",
				"--backend", r.MLXBackend,
				"--hostfile", hostfile,
				"--rank", "0",
			)
			cmd.Env = os.Environ()
			cmd.Stderr = os.Stderr
			cmd.Stdout = os.Stdout

			if err := cmd.Run(); err != nil {
				xlog.Error("mlx-distributed worker exited", "error", err)
			}
			time.Sleep(2 * time.Second)
		}
	}()

	_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID))
	if err != nil {
		return err
	}

	xlog.Info("MLX distributed worker registered on P2P network", "address", address, "port", port)

	signals.RegisterGracefulTerminationHandler(func() {
		cancel()
	})

	<-c.Done()
	return nil
}
@@ -82,7 +82,8 @@ type ApplicationConfig struct {

	APIAddress string

-	TunnelCallback func(tunnels []string)
+	LlamaCPPTunnelCallback func(tunnels []string)
+	MLXTunnelCallback      func(tunnels []string)

	DisableRuntimeSettings bool

@@ -457,9 +458,15 @@ func WithContextSize(ctxSize int) AppOption {
	}
}

-func WithTunnelCallback(callback func(tunnels []string)) AppOption {
+func WithLlamaCPPTunnelCallback(callback func(tunnels []string)) AppOption {
	return func(o *ApplicationConfig) {
-		o.TunnelCallback = callback
+		o.LlamaCPPTunnelCallback = callback
	}
}

func WithMLXTunnelCallback(callback func(tunnels []string)) AppOption {
	return func(o *ApplicationConfig) {
		o.MLXTunnelCallback = callback
	}
}

@@ -156,7 +156,7 @@ func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockch
	for d := range data {
		toScanForWorkers := false
		cd := ClusterData{}
-		isWorkerCluster := d == p2p.WorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.WorkerID))
+		isWorkerCluster := d == p2p.LlamaCPPWorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.LlamaCPPWorkerID))
		isFederatedCluster := d == p2p.FederatedID || (strings.Contains(d, "_") && strings.Contains(d, p2p.FederatedID))
		switch {
		case isWorkerCluster:

@@ -15,8 +15,9 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
	// Render index
	return func(c echo.Context) error {
		return c.JSON(200, schema.P2PNodesResponse{
-			Nodes:          p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
+			LlamaCPPNodes:  p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)),
			FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
			MLXNodes:       p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)),
		})
	}
}

@@ -103,8 +103,9 @@ function StepNumber({ n, bg, color }) {
export default function P2P() {
  const { addToast } = useOutletContext()
  const [workers, setWorkers] = useState([])
  const [mlxWorkers, setMlxWorkers] = useState([])
  const [federation, setFederation] = useState([])
- const [stats, setStats] = useState({ workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 } })
+ const [stats, setStats] = useState({ llama_cpp_workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 }, mlx_workers: { online: 0, total: 0 } })
  const [loading, setLoading] = useState(true)
  const [enabled, setEnabled] = useState(false)
  const [token, setToken] = useState('')
@@ -129,7 +130,13 @@ export default function P2P() {
    if (p2pToken) {
      if (wRes.status === 'fulfilled') {
        const data = wRes.value
-       setWorkers(data?.nodes || (Array.isArray(data) ? data : []))
+       // Handle both old format ({nodes: [...]}) and new grouped format ({llama_cpp: {nodes: [...]}, mlx: {nodes: [...]}})
+       if (data?.llama_cpp) {
+         setWorkers(data.llama_cpp.nodes || [])
+         setMlxWorkers(data.mlx?.nodes || [])
+       } else {
+         setWorkers(data?.nodes || (Array.isArray(data) ? data : []))
+       }
      }
      if (fRes.status === 'fulfilled') {
        const data = fRes.value
@@ -274,8 +281,10 @@ export default function P2P() {
  // ── P2P Enabled ──
  const fedOnline = stats.federated?.online ?? 0
  const fedTotal = stats.federated?.total ?? 0
- const wrkOnline = stats.workers?.online ?? 0
- const wrkTotal = stats.workers?.total ?? 0
+ const llamaOnline = stats.llama_cpp_workers?.online ?? 0
+ const llamaTotal = stats.llama_cpp_workers?.total ?? 0
+ const mlxOnline = stats.mlx_workers?.online ?? 0
+ const mlxTotal = stats.mlx_workers?.total ?? 0

  return (
    <div className="page">
@@ -401,7 +410,7 @@ export default function P2P() {
              fontSize: '0.75rem',
              color: activeTab === 'sharding' ? 'var(--color-accent)' : 'var(--color-text-muted)',
            }}>
-             {wrkOnline}/{wrkTotal} workers
+             {llamaOnline + mlxOnline}/{llamaTotal + mlxTotal} workers
            </div>
          </div>
        </div>
@@ -562,6 +571,21 @@ export default function P2P() {
          borderRadius: 'var(--radius-lg)', overflow: 'hidden',
        }}>
          <div style={{ padding: 'var(--spacing-lg)', borderBottom: '1px solid var(--color-border-subtle)' }}>
            <div style={{
              background: 'var(--color-accent-light)', border: '1px solid rgba(139,92,246,0.3)',
              borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm) var(--spacing-md)',
              fontSize: '0.8125rem', color: 'var(--color-text-secondary)', marginBottom: 'var(--spacing-md)',
            }}>
              <i className="fas fa-info-circle" style={{ color: 'var(--color-accent)', marginRight: 6 }} />
              <strong>Different from federation:</strong> Federation distributes whole requests across instances. Model sharding splits a single model across machines for joint inference.
            </div>

            {/* ── llama.cpp RPC Workers Section ── */}
            <h3 style={{ fontSize: '1.125rem', fontWeight: 700, marginBottom: 'var(--spacing-sm)', display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)' }}>
              <i className="fas fa-microchip" style={{ color: 'var(--color-accent)' }} />
              llama.cpp RPC Workers
            </h3>

            {/* Architecture diagram */}
            <div style={{
              background: 'var(--color-bg-primary)', border: '1px solid var(--color-border-subtle)',
@@ -604,25 +628,15 @@ export default function P2P() {
              </div>
              <p style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)', textAlign: 'center', marginTop: 'var(--spacing-sm)', lineHeight: 1.5 }}>
                Model weights are <strong>split across RPC workers</strong>. Each worker holds a portion of the model layers in its memory (GPU or CPU).
                The LocalAI instance orchestrates inference by communicating with all workers via RPC.
              </p>
            </div>

            <div style={{
              background: 'var(--color-accent-light)', border: '1px solid rgba(139,92,246,0.3)',
              borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm) var(--spacing-md)',
              fontSize: '0.8125rem', color: 'var(--color-text-secondary)', marginBottom: 'var(--spacing-md)',
            }}>
              <i className="fas fa-info-circle" style={{ color: 'var(--color-accent)', marginRight: 6 }} />
              <strong>Different from federation:</strong> Federation distributes whole requests across instances. Model sharding splits a single model's weights across machines for joint inference. Currently only supported with <strong>llama.cpp</strong> based models.
            </div>

-           {/* Status + nodes */}
+           {/* llama.cpp Status + nodes */}
            <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-md)' }}>
-             <h3 style={{ fontSize: '1rem', fontWeight: 700 }}>Connected Workers</h3>
+             <h4 style={{ fontSize: '1rem', fontWeight: 600 }}>Connected Workers</h4>
              <div style={{ fontSize: '1.25rem', fontWeight: 700 }}>
-               <span style={{ color: wrkOnline > 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{wrkOnline}</span>
-               <span style={{ color: 'var(--color-text-secondary)', fontSize: '1rem' }}>/{wrkTotal}</span>
+               <span style={{ color: llamaOnline > 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{llamaOnline}</span>
+               <span style={{ color: 'var(--color-text-secondary)', fontSize: '1rem' }}>/{llamaTotal}</span>
              </div>
            </div>

@@ -633,7 +647,7 @@ export default function P2P() {
              borderRadius: 'var(--radius-lg)',
            }}>
              <i className="fas fa-puzzle-piece" style={{ fontSize: '2rem', color: 'var(--color-text-muted)', marginBottom: 'var(--spacing-sm)' }} />
-             <p style={{ fontWeight: 500, color: 'var(--color-text-secondary)' }}>No workers available</p>
+             <p style={{ fontWeight: 500, color: 'var(--color-text-secondary)' }}>No llama.cpp workers connected</p>
              <p style={{ fontSize: '0.8125rem', color: 'var(--color-text-muted)', marginTop: 'var(--spacing-xs)' }}>Start workers to see them here</p>
            </div>
          ) : (
@@ -645,17 +659,99 @@ export default function P2P() {
            )}
          </div>

-         {/* Setup Guide */}
+         {/* ── MLX Distributed Workers Section ── */}
          <div style={{ padding: 'var(--spacing-lg)', borderBottom: '1px solid var(--color-border-subtle)' }}>
            <h3 style={{ fontSize: '1.125rem', fontWeight: 700, marginBottom: 'var(--spacing-sm)', display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)' }}>
              <i className="fas fa-apple-whole" style={{ color: 'rgba(245,158,11,1)' }} />
              MLX Distributed Workers
            </h3>

            {/* MLX Architecture diagram */}
            <div style={{
              background: 'var(--color-bg-primary)', border: '1px solid var(--color-border-subtle)',
              borderRadius: 'var(--radius-lg)', padding: 'var(--spacing-md)', marginBottom: 'var(--spacing-md)',
            }}>
              <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', gap: 'var(--spacing-md)', flexWrap: 'wrap' }}>
                <div style={{ textAlign: 'center' }}>
                  <div style={{
                    width: 48, height: 48, borderRadius: 'var(--radius-md)',
                    background: 'var(--color-primary-light)', display: 'flex', alignItems: 'center', justifyContent: 'center',
                    margin: '0 auto var(--spacing-xs)', border: '2px solid var(--color-primary)',
                  }}>
                    <i className="fas fa-server" style={{ color: 'var(--color-primary)', fontSize: '1rem' }} />
                  </div>
                  <div style={{ fontSize: '0.75rem', fontWeight: 600 }}>LocalAI</div>
                  <div style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>Rank 0</div>
                </div>
                <div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: '2px' }}>
                  <i className="fas fa-arrows-left-right" style={{ color: 'var(--color-text-muted)', fontSize: '0.875rem' }} />
                  <span style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>Ring / JACCL</span>
                </div>
                <div style={{ textAlign: 'center' }}>
                  <div style={{ display: 'flex', gap: '4px', marginBottom: 'var(--spacing-xs)' }}>
                    {['Layers 1-16', 'Layers 17-32'].map((label, i) => (
                      <div key={i} style={{ textAlign: 'center' }}>
                        <div style={{
                          width: 64, height: 36, borderRadius: 'var(--radius-sm)',
                          background: 'rgba(245,158,11,0.1)', display: 'flex', alignItems: 'center', justifyContent: 'center',
                          border: '1px solid rgba(245,158,11,0.3)',
                        }}>
                          <i className="fas fa-microchip" style={{ color: 'rgba(245,158,11,1)', fontSize: '0.75rem' }} />
                        </div>
                        <div style={{ fontSize: '0.5625rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{label}</div>
                      </div>
                    ))}
                  </div>
                  <div style={{ fontSize: '0.75rem', fontWeight: 600 }}>MLX Workers</div>
                  <div style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>Pipeline parallel</div>
                </div>
              </div>
              <p style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)', textAlign: 'center', marginTop: 'var(--spacing-sm)', lineHeight: 1.5 }}>
                MLX distributed uses <strong>native Apple Silicon communication</strong>. All nodes execute model code simultaneously via pipeline or tensor parallelism.
              </p>
            </div>

            {/* MLX Status + nodes */}
            <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-md)' }}>
              <h4 style={{ fontSize: '1rem', fontWeight: 600 }}>Connected MLX Workers</h4>
              <div style={{ fontSize: '1.25rem', fontWeight: 700 }}>
                <span style={{ color: mlxOnline > 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{mlxOnline}</span>
                <span style={{ color: 'var(--color-text-secondary)', fontSize: '1rem' }}>/{mlxTotal}</span>
              </div>
            </div>

            {mlxWorkers.length === 0 ? (
              <div style={{
                textAlign: 'center', padding: 'var(--spacing-lg)',
                background: 'var(--color-bg-primary)', border: '1px solid var(--color-border-subtle)',
                borderRadius: 'var(--radius-lg)',
              }}>
                <i className="fas fa-apple-whole" style={{ fontSize: '2rem', color: 'var(--color-text-muted)', marginBottom: 'var(--spacing-sm)' }} />
                <p style={{ fontWeight: 500, color: 'var(--color-text-secondary)' }}>No MLX workers connected</p>
                <p style={{ fontSize: '0.8125rem', color: 'var(--color-text-muted)', marginTop: 'var(--spacing-xs)' }}>Start MLX workers on Apple Silicon Macs</p>
              </div>
            ) : (
              <div style={{ display: 'grid', gridTemplateColumns: 'repeat(auto-fill, minmax(260px, 1fr))', gap: 'var(--spacing-md)' }}>
                {mlxWorkers.map((node, i) => (
                  <NodeCard key={node.id || i} node={node} label={`MLX Rank ${i + 1}`} iconColor="rgba(245,158,11,1)" iconBg="rgba(245,158,11,0.1)" />
                ))}
              </div>
            )}
          </div>

          {/* Setup Guides */}
          <div style={{ padding: 'var(--spacing-lg)' }}>
            <h3 style={{ fontSize: '1.125rem', fontWeight: 700, marginBottom: 'var(--spacing-md)' }}>
              <i className="fas fa-book" style={{ color: 'var(--color-accent)', marginRight: 'var(--spacing-sm)' }} />
-             Start a llama.cpp RPC Worker
+             Setup Workers
            </h3>

            <div style={{
              background: 'var(--color-bg-primary)', borderRadius: 'var(--radius-lg)',
              border: '1px solid var(--color-border-subtle)', padding: 'var(--spacing-lg)',
              marginBottom: 'var(--spacing-md)',
            }}>
              <h4 style={{ fontSize: '1rem', fontWeight: 600, marginBottom: 'var(--spacing-sm)' }}>llama.cpp RPC Worker</h4>
              <p style={{ color: 'var(--color-text-secondary)', fontSize: '0.875rem', marginBottom: 'var(--spacing-sm)' }}>
                Each worker exposes its GPU/CPU memory as a shard for distributed model inference.
              </p>
@@ -663,17 +759,24 @@ export default function P2P() {
              command={`docker run -ti --net host \\\n -e TOKEN="${token}" \\\n --name local-ai-worker \\\n localai/localai:latest-cpu worker p2p-llama-cpp-rpc`}
              addToast={addToast}
            />
            <p style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem', marginTop: 'var(--spacing-sm)' }}>
              Run this on each machine you want to contribute as a shard. The worker will automatically join the network and advertise its resources.
            </p>
          </div>

-         <p style={{ color: 'var(--color-text-secondary)', fontSize: '0.875rem', marginTop: 'var(--spacing-lg)' }}>
-           For GPU images and all available options, see the{' '}
-           <a href="https://localai.io/basics/container/" target="_blank" rel="noopener noreferrer"
-             style={{ color: 'var(--color-accent)' }}>Container images</a>
-           {' '}and{' '}
-           <a href="https://localai.io/features/distribute/#starting-workers" target="_blank" rel="noopener noreferrer"
-             style={{ color: 'var(--color-accent)' }}>Worker</a> docs.
+         <div style={{
+           background: 'var(--color-bg-primary)', borderRadius: 'var(--radius-lg)',
+           border: '1px solid rgba(245,158,11,0.3)', padding: 'var(--spacing-lg)',
+         }}>
+           <h4 style={{ fontSize: '1rem', fontWeight: 600, marginBottom: 'var(--spacing-sm)' }}>MLX Distributed Worker</h4>
+           <p style={{ color: 'var(--color-text-secondary)', fontSize: '0.875rem', marginBottom: 'var(--spacing-sm)' }}>
+             Run on Apple Silicon Macs to participate in distributed MLX inference via pipeline parallelism.
+           </p>
+           <CommandBlock
+             command={`docker run -ti --net host \\\n -e TOKEN="${token}" \\\n --name local-ai-mlx-worker \\\n localai/localai:latest-metal-darwin-arm64 worker p2p-mlx`}
+             addToast={addToast}
+           />
+           <p style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem', marginTop: 'var(--spacing-sm)' }}>
+             For more information, see the{' '}
+             <a href="https://localai.io/features/mlx-distributed/" target="_blank" rel="noopener noreferrer"
+               style={{ color: 'rgba(245,158,11,1)' }}>MLX Distributed</a> docs.
            </p>
          </div>
        </div>

@@ -1041,11 +1041,12 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model

	// P2P APIs
	app.GET("/api/p2p/workers", func(c echo.Context) error {
-		nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
+		llamaNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID))
+		mlxNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID))

-		nodesJSON := make([]map[string]interface{}, 0, len(nodes))
-		for _, n := range nodes {
-			nodesJSON = append(nodesJSON, map[string]interface{}{
+		llamaJSON := make([]map[string]any, 0, len(llamaNodes))
+		for _, n := range llamaNodes {
+			llamaJSON = append(llamaJSON, map[string]any{
				"name":          n.Name,
				"id":            n.ID,
				"tunnelAddress": n.TunnelAddress,
@@ -1055,8 +1056,27 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
			})
		}

-		return c.JSON(200, map[string]interface{}{
-			"nodes": nodesJSON,
+		mlxJSON := make([]map[string]any, 0, len(mlxNodes))
+		for _, n := range mlxNodes {
+			mlxJSON = append(mlxJSON, map[string]any{
+				"name":          n.Name,
+				"id":            n.ID,
+				"tunnelAddress": n.TunnelAddress,
+				"serviceID":     n.ServiceID,
+				"lastSeen":      n.LastSeen,
+				"isOnline":      n.IsOnline(),
+			})
+		}
+
+		return c.JSON(200, map[string]any{
+			"llama_cpp": map[string]any{
+				"nodes": llamaJSON,
+			},
+			"mlx": map[string]any{
+				"nodes": mlxJSON,
+			},
+			// Keep backward-compatible "nodes" key with llama.cpp workers
+			"nodes": llamaJSON,
		})
	})

@@ -1081,13 +1101,14 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
	})

	app.GET("/api/p2p/stats", func(c echo.Context) error {
-		workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))
+		llamaCPPNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID))
		federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
+		mlxWorkerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID))

-		workersOnline := 0
-		for _, n := range workerNodes {
+		llamaCPPOnline := 0
+		for _, n := range llamaCPPNodes {
			if n.IsOnline() {
-				workersOnline++
+				llamaCPPOnline++
			}
		}

@@ -1098,15 +1119,26 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
			}
		}

-		return c.JSON(200, map[string]interface{}{
-			"workers": map[string]interface{}{
-				"online": workersOnline,
-				"total":  len(workerNodes),
+		mlxWorkersOnline := 0
+		for _, n := range mlxWorkerNodes {
+			if n.IsOnline() {
+				mlxWorkersOnline++
+			}
+		}
+
+		return c.JSON(200, map[string]any{
+			"llama_cpp_workers": map[string]any{
+				"online": llamaCPPOnline,
+				"total":  len(llamaCPPNodes),
			},
-			"federated": map[string]interface{}{
+			"federated": map[string]any{
				"online": federatedOnline,
				"total":  len(federatedNodes),
			},
+			"mlx_workers": map[string]any{
+				"online": mlxWorkersOnline,
+				"total":  len(mlxWorkerNodes),
+			},
		})
	})

@@ -262,8 +262,8 @@
        </div>
        <div class="text-right">
          <div class="text-2xl font-bold">
-           <span :class="stats.workers.online > 0 ? 'text-[var(--color-success)]' : 'text-[var(--color-error)]'" x-text="stats.workers.online"></span>
-           <span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.workers.total"></span></span>
+           <span :class="stats.llama_cpp_workers.online > 0 ? 'text-[var(--color-success)]' : 'text-[var(--color-error)]'" x-text="stats.llama_cpp_workers.online"></span>
+           <span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.llama_cpp_workers.total"></span></span>
          </div>
          <p class="text-[var(--color-accent)] text-sm">workers</p>
        </div>
@@ -469,8 +469,8 @@ docker run -ti --net host -e TOKEN="<span class="token">{{.P2PToken}}</span>" --
        <div class="text-right">
          <div class="text-sm text-[var(--color-text-secondary)] mb-1">Active Workers</div>
          <div class="text-3xl font-bold">
-           <span :class="stats.workers.online > 0 ? 'text-[var(--color-accent)]' : 'text-[var(--color-error)]'" x-text="stats.workers.online"></span>
-           <span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.workers.total"></span></span>
+           <span :class="stats.llama_cpp_workers.online > 0 ? 'text-[var(--color-accent)]' : 'text-[var(--color-error)]'" x-text="stats.llama_cpp_workers.online"></span>
+           <span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.llama_cpp_workers.total"></span></span>
          </div>
        </div>
      </div>
@@ -657,7 +657,7 @@ function p2pNetwork() {
  workerNodes: [],
  federationNodes: [],
  stats: {
-   workers: { online: 0, total: 0 },
+   llama_cpp_workers: { online: 0, total: 0 },
    federated: { online: 0, total: 0 }
  },

@@ -9,8 +9,9 @@ import (
)

const (
-	defaultServicesID = "services"
-	WorkerID          = "worker"
+	defaultServicesID = "services"
+	LlamaCPPWorkerID  = "worker"
+	MLXWorkerID       = "mlx_worker"
)

var mu sync.Mutex

@@ -136,8 +136,9 @@ func (d NodeData) IsOnline() bool {
}

type P2PNodesResponse struct {
-	Nodes          []NodeData `json:"nodes" yaml:"nodes"`
+	LlamaCPPNodes  []NodeData `json:"llama_cpp_nodes" yaml:"llama_cpp_nodes"`
	FederatedNodes []NodeData `json:"federated_nodes" yaml:"federated_nodes"`
	MLXNodes       []NodeData `json:"mlx_nodes" yaml:"mlx_nodes"`
}

type SysInfoModel struct {

212  docs/content/features/mlx-distributed.md  Normal file
@@ -0,0 +1,212 @@
+++
disableToc = false
title = "(experimental) MLX Distributed Inference"
weight = 18
url = '/features/mlx-distributed/'
+++

MLX distributed inference allows you to split large language models across multiple Apple Silicon Macs (or other devices) for joint inference. Unlike federation (which distributes whole requests), MLX distributed splits a single model's layers across machines so they all participate in every forward pass.

## How It Works

MLX distributed uses **pipeline parallelism** via the Ring backend: each node holds a slice of the model's layers. During inference, activations flow from rank 0 through each subsequent rank in a pipeline. The last rank gathers the final output.

For high-bandwidth setups (e.g., Thunderbolt-connected Macs), **JACCL** (tensor parallelism via RDMA) is also supported, where each rank holds all layers but with sharded weights.
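The even split of contiguous layers across ranks described above can be sketched as follows. This is an illustrative helper, not the backend's actual auto-parallel code:

```python
# Illustrative sketch of pipeline-parallel layer assignment: split n_layers
# into contiguous slices, one per rank, with any remainder going to the
# first ranks.
def partition_layers(n_layers: int, n_ranks: int) -> list[range]:
    base, extra = divmod(n_layers, n_ranks)
    slices, start = [], 0
    for rank in range(n_ranks):
        count = base + (1 if rank < extra else 0)
        slices.append(range(start, start + count))
        start += count
    return slices

# Two ranks over a 32-layer model: rank 0 owns layers 0-15, rank 1 owns 16-31.
for rank, layers in enumerate(partition_layers(32, 2)):
    print(f"rank {rank}: layers {layers.start}-{layers.stop - 1}")
```

During inference, each rank runs only its own slice and ships the activations to the next rank, which is why every rank must be reachable at its hostfile address.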
## Prerequisites

- Two or more machines with MLX installed (Apple Silicon recommended)
- Network connectivity between all nodes (TCP for Ring, RDMA/Thunderbolt for JACCL)
- Same model accessible on all nodes (e.g., from Hugging Face cache)

## Quick Start with P2P

The simplest way to use MLX distributed is with LocalAI's P2P auto-discovery.

### 1. Start LocalAI with P2P

```bash
docker run -ti --net host \
  --name local-ai \
  localai/localai:latest-metal-darwin-arm64 run --p2p
```

This generates a network token. Copy it for the next step.

### 2. Start MLX Workers

On each additional Mac:

```bash
docker run -ti --net host \
  -e TOKEN="<your-token>" \
  --name local-ai-mlx-worker \
  localai/localai:latest-metal-darwin-arm64 worker p2p-mlx
```

Workers auto-register on the P2P network. The LocalAI server discovers them and generates a hostfile for MLX distributed.
### 3. Use the Model
|
||||
|
||||
Load any MLX-compatible model. The `mlx-distributed` backend will automatically shard it across all available ranks:
|
||||
|
||||
```yaml
|
||||
name: llama-distributed
|
||||
backend: mlx-distributed
|
||||
parameters:
|
||||
model: mlx-community/Llama-3.2-1B-Instruct-4bit
|
||||
```
|
||||
|
||||
## Model Configuration

The `mlx-distributed` backend is started automatically by LocalAI like any other backend. You configure distributed inference through the model YAML file using the `options` field.

### Ring Backend (TCP)

```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
  model: mlx-community/Llama-3.2-1B-Instruct-4bit
options:
  - "hostfile:/path/to/hosts.json"
  - "distributed_backend:ring"
```

The **hostfile** is a JSON array where entry `i` is the `"ip:port"` that **rank `i` listens on** for ring communication. All ranks must use the same hostfile so they know how to reach each other.

**Example:** Two Macs — Mac A (`192.168.1.10`) and Mac B (`192.168.1.11`):

```json
["192.168.1.10:5555", "192.168.1.11:5555"]
```

- Entry 0 (`192.168.1.10:5555`) — the address rank 0 (Mac A) listens on for ring communication
- Entry 1 (`192.168.1.11:5555`) — the address rank 1 (Mac B) listens on for ring communication

Port 5555 is arbitrary — use any available port, but it must be open in your firewall.
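To illustrate the format, here is a small sketch (a hypothetical helper, not part of LocalAI) that parses a Ring hostfile and maps each rank to its listen address:

```python
# Illustrative sketch (hypothetical helper, not part of LocalAI): sanity-check
# a Ring hostfile and report each rank's listen address.
import json

def check_ring_hostfile(text: str) -> dict[int, str]:
    hosts = json.loads(text)
    if not isinstance(hosts, list) or len(hosts) < 2:
        raise ValueError("hostfile must be a JSON array with at least two entries")
    for entry in hosts:
        host, sep, port = entry.rpartition(":")
        if not sep or not host or not port.isdigit():
            raise ValueError(f"entry {entry!r} is not in 'ip:port' form")
    # Rank i listens on entry i.
    return dict(enumerate(hosts))

print(check_ring_hostfile('["192.168.1.10:5555", "192.168.1.11:5555"]'))
# {0: '192.168.1.10:5555', 1: '192.168.1.11:5555'}
```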
### JACCL Backend (RDMA/Thunderbolt)

```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
  model: mlx-community/Llama-3.2-1B-Instruct-4bit
options:
  - "hostfile:/path/to/devices.json"
  - "distributed_backend:jaccl"
```

The **device matrix** is a JSON 2D array describing the RDMA device name between each pair of ranks. The diagonal is `null` (a rank doesn't talk to itself):

```json
[
  [null, "rdma_thunderbolt0"],
  ["rdma_thunderbolt0", null]
]
```

JACCL requires a **coordinator** — a TCP service that helps all ranks establish RDMA connections. Rank 0 (the LocalAI machine) is always the coordinator. Workers are told the coordinator address via their `--coordinator` CLI flag (see [Starting Workers](#jaccl-workers) below).
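The matrix in the example above is square, symmetric, and `null` on the diagonal. A quick sanity check under those assumptions (a hypothetical helper, not part of LocalAI) could look like:

```python
# Illustrative sketch (hypothetical helper): validate a JACCL device matrix:
# square, null on the diagonal, and (assumed) symmetric device names
# between each pair of ranks.
import json

def check_device_matrix(text: str) -> int:
    m = json.loads(text)
    n = len(m)
    for i, row in enumerate(m):
        if len(row) != n:
            raise ValueError("matrix must be square")
        if row[i] is not None:
            raise ValueError(f"diagonal entry [{i}][{i}] must be null")
        for j in range(n):
            if i != j and m[i][j] != m[j][i]:
                raise ValueError(f"device for ranks {i}/{j} must match both ways")
    return n  # number of ranks

print(check_device_matrix('[[null, "rdma_thunderbolt0"], ["rdma_thunderbolt0", null]]'))
# 2
```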
### Without hostfile (single-node)

If no `hostfile` option is set and no `MLX_DISTRIBUTED_HOSTFILE` environment variable exists, the backend runs as a regular single-node MLX backend. This is useful for testing or when you don't need distributed inference.

### Available Options

| Option | Description |
|--------|-------------|
| `hostfile` | Path to the hostfile JSON. Ring: array of `"ip:port"`. JACCL: device matrix. |
| `distributed_backend` | `ring` (default) or `jaccl` |
| `trust_remote_code` | Allow `trust_remote_code` for the tokenizer |
| `max_tokens` | Override default max generation tokens |
| `temperature` / `temp` | Sampling temperature |
| `top_p` | Top-p sampling |

These can also be set via environment variables (`MLX_DISTRIBUTED_HOSTFILE`, `MLX_DISTRIBUTED_BACKEND`), which are used as fallbacks when the model options don't specify them.
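The `key:value` option strings with an environment-variable fallback can be illustrated as follows (a hypothetical parser, not LocalAI's actual implementation):

```python
# Illustrative sketch (hypothetical parser): resolve an option from the model
# config's "key:value" option strings, falling back to an environment variable.
import os

def resolve_option(options: list[str], key: str, env: str, default=None):
    for opt in options:
        name, sep, value = opt.partition(":")
        if sep and name == key:
            return value
    # Not set in the model config: fall back to the environment.
    return os.environ.get(env, default)

opts = ["hostfile:/path/to/hosts.json", "distributed_backend:ring"]
print(resolve_option(opts, "hostfile", "MLX_DISTRIBUTED_HOSTFILE"))
# /path/to/hosts.json
```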
## Starting Workers

LocalAI starts the rank 0 process (gRPC server) automatically when the model is loaded. But you still need to start **worker processes** (ranks 1, 2, ...) on the other machines. These workers participate in every forward pass but don't serve any API — they wait for commands from rank 0.

### Ring Workers

On each worker machine, start a worker with the same hostfile:

```bash
local-ai worker mlx-distributed --hostfile hosts.json --rank 1
```

The `--rank` must match the worker's position in the hostfile. For example, if `hosts.json` is `["192.168.1.10:5555", "192.168.1.11:5555", "192.168.1.12:5555"]`, then:

- Rank 0: started automatically by LocalAI on `192.168.1.10`
- Rank 1: `local-ai worker mlx-distributed --hostfile hosts.json --rank 1` on `192.168.1.11`
- Rank 2: `local-ai worker mlx-distributed --hostfile hosts.json --rank 2` on `192.168.1.12`

### JACCL Workers

```bash
local-ai worker mlx-distributed \
  --hostfile devices.json \
  --rank 1 \
  --backend jaccl \
  --coordinator 192.168.1.10:5555
```

The `--coordinator` address is the IP of the machine running LocalAI (rank 0) with any available port. Rank 0 binds the coordinator service there; workers connect to it to establish RDMA connections.

### Worker Startup Order

Start workers **before** loading the model in LocalAI. When LocalAI sends the LoadModel request, rank 0 initializes `mx.distributed`, which tries to connect to all ranks listed in the hostfile. If workers aren't running yet, it will time out.
## Advanced: Manual Rank 0

For advanced use cases, you can also run rank 0 manually as an external gRPC backend instead of letting LocalAI start it automatically:

```bash
# On Mac A: start rank 0 manually
local-ai worker mlx-distributed --hostfile hosts.json --rank 0 --addr 192.168.1.10:50051

# On Mac B: start rank 1
local-ai worker mlx-distributed --hostfile hosts.json --rank 1

# On any machine: start LocalAI pointing at rank 0
local-ai run --external-grpc-backends "mlx-distributed:192.168.1.10:50051"
```

Then use a model config with `backend: mlx-distributed` (no need for `hostfile` in options, since rank 0 already has it from its CLI args).
## CLI Reference

### `worker mlx-distributed`

Starts a worker or a manual rank 0 process.

| Flag | Env | Default | Description |
|------|-----|---------|-------------|
| `--hostfile` | `MLX_DISTRIBUTED_HOSTFILE` | *(required)* | Path to hostfile JSON. Ring: array of `"ip:port"` where entry `i` is rank `i`'s listen address. JACCL: device matrix of RDMA device names. |
| `--rank` | `MLX_RANK` | *(required)* | Rank of this process (0 = gRPC server + ring participant, >0 = worker only) |
| `--backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | `ring` (TCP pipeline parallelism) or `jaccl` (RDMA tensor parallelism) |
| `--addr` | `MLX_DISTRIBUTED_ADDR` | `localhost:50051` | gRPC API listen address (rank 0 only, for LocalAI or external access) |
| `--coordinator` | `MLX_JACCL_COORDINATOR` | | JACCL coordinator `ip:port` — rank 0's address for RDMA setup (all ranks must use the same value) |

### `worker p2p-mlx`

P2P mode — auto-discovers peers and generates the hostfile.

| Flag | Env | Default | Description |
|------|-----|---------|-------------|
| `--token` | `TOKEN` | *(required)* | P2P network token |
| `--mlx-listen-port` | `MLX_LISTEN_PORT` | `5555` | Port for MLX communication |
| `--mlx-backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | Backend type: `ring` or `jaccl` |
## Troubleshooting

- **All ranks download the model independently.** Each node auto-downloads from Hugging Face on first use via `mlx_lm.load()`. On rank 0 (started by LocalAI), models are downloaded to LocalAI's model directory (`HF_HOME` is set automatically). On workers, models go to the default HF cache (`~/.cache/huggingface/hub`) unless you set `HF_HOME` yourself.
- **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. Start workers before loading the model.
- **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs.
- **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory.
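The P2P rank-assignment convention noted above can be sketched as follows (an illustration of the stated rule, not LocalAI's actual code):

```python
# Illustrative sketch (assumed convention): in P2P mode the LocalAI server is
# always rank 0, and worker ranks follow the sorted order of their node IDs.
def assign_ranks(server_id: str, worker_ids: list[str]) -> dict[str, int]:
    ranks = {server_id: 0}
    for offset, node_id in enumerate(sorted(worker_ids), start=1):
        ranks[node_id] = offset
    return ranks

print(assign_ranks("server", ["node-b", "node-a"]))
# {'server': 0, 'node-a': 1, 'node-b': 2}
```

Because the ordering is deterministic, every node derives the same rank mapping without extra coordination.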
## Acknowledgements

The MLX distributed auto-parallel sharding implementation is based on [exo](https://github.com/exo-explore/exo).