feat(mlx-distributed): add new MLX-distributed backend (#8801)

* feat(mlx-distributed): add new MLX-distributed backend

Add a new MLX distributed backend with support for both TCP and RDMA
transports for model sharding.

This implementation ties into the discovery mechanism already in
place, and reuses the same P2P mechanism for TCP-based MLX-distributed
inference.

The auto-parallel implementation is inspired by Exo's (the Exo project
has been added to the acknowledgements for the great work!)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* expose a CLI to facilitate backend starting

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: make manual rank0 configurable via model configs

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add missing features from mlx backend

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Apply suggestion from @mudler

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Ettore Di Giacinto
2026-03-09 17:29:32 +01:00
committed by GitHub
parent 734b6d391f
commit a026277ab9
36 changed files with 2016 additions and 73 deletions


@@ -157,6 +157,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: ''
cuda-major-version: ""
cuda-minor-version: ""
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-cpu-mlx-distributed'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'true'
backend: "mlx-distributed"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
# CUDA 12 builds
- build-type: 'cublas'
cuda-major-version: "12"
@@ -470,6 +483,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-12-mlx-distributed'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "mlx-distributed"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "8"
@@ -822,6 +848,19 @@ jobs:
backend: "mlx-audio"
dockerfile: "./backend/Dockerfile.python"
context: "./"
- build-type: 'l4t'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/arm64'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx-distributed'
runs-on: 'ubuntu-24.04-arm'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
ubuntu-version: '2404'
backend: "mlx-distributed"
dockerfile: "./backend/Dockerfile.python"
context: "./"
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -926,6 +965,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-gpu-nvidia-cuda-13-mlx-distributed'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:24.04"
skip-drivers: 'false'
backend: "mlx-distributed"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2404'
- build-type: 'cublas'
cuda-major-version: "13"
cuda-minor-version: "0"
@@ -1423,6 +1475,19 @@ jobs:
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2204'
- build-type: 'l4t'
cuda-major-version: "12"
cuda-minor-version: "0"
platforms: 'linux/arm64'
tag-latest: 'auto'
tag-suffix: '-nvidia-l4t-mlx-distributed'
runs-on: 'ubuntu-24.04-arm'
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
skip-drivers: 'true'
backend: "mlx-distributed"
dockerfile: "./backend/Dockerfile.python"
context: "./"
ubuntu-version: '2204'
# SYCL additional backends
- build-type: 'intel'
cuda-major-version: ""
@@ -2016,6 +2081,9 @@ jobs:
- backend: "mlx-audio"
tag-suffix: "-metal-darwin-arm64-mlx-audio"
build-type: "mps"
- backend: "mlx-distributed"
tag-suffix: "-metal-darwin-arm64-mlx-distributed"
build-type: "mps"
- backend: "stablediffusion-ggml"
tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml"
build-type: "metal"


@@ -1,5 +1,5 @@
# Disable parallel execution for backend builds
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral
.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral
GOCMD=go
GOTEST=$(GOCMD) test
@@ -451,6 +451,10 @@ backends/mlx-audio:
BACKEND=mlx-audio $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)"
backends/mlx-distributed:
BACKEND=mlx-distributed $(MAKE) build-darwin-python-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-distributed.tar)"
backends/stablediffusion-ggml-darwin:
BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend
./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)"
@@ -495,6 +499,7 @@ BACKEND_NEMO = nemo|python|.|false|true
BACKEND_VOXCPM = voxcpm|python|.|false|true
BACKEND_WHISPERX = whisperx|python|.|false|true
BACKEND_ACE_STEP = ace-step|python|.|false|true
BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true
# Helper function to build docker image for a backend
# Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG)
@@ -548,12 +553,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_NEMO)))
$(eval $(call generate-docker-build-target,$(BACKEND_VOXCPM)))
$(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX)))
$(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP)))
$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED)))
# Pattern rule for docker-save targets
docker-save-%: backend-images
docker save local-ai-backend:$* -o backend-images/$*.tar
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral
docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral docker-build-mlx-distributed
########################################################
### Mock Backend for E2E Tests


@@ -482,6 +482,7 @@ LocalAI couldn't have been built without the help of great software already avai
- https://github.com/EdVince/Stable-Diffusion-NCNN
- https://github.com/ggerganov/whisper.cpp
- https://github.com/rhasspy/piper
- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation
## 🤗 Contributors


@@ -259,6 +259,31 @@
nvidia-l4t: "nvidia-l4t-mlx-audio"
nvidia-l4t-cuda-12: "nvidia-l4t-mlx-audio"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-audio"
- &mlx-distributed
name: "mlx-distributed"
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-distributed"
icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
urls:
- https://github.com/ml-explore/mlx-lm
mirrors:
- localai/localai-backends:latest-metal-darwin-arm64-mlx-distributed
license: MIT
description: |
Run distributed LLM inference with MLX across multiple Apple Silicon Macs
tags:
- text-to-text
- LLM
- MLX
- distributed
capabilities:
default: "cpu-mlx-distributed"
nvidia: "cuda12-mlx-distributed"
metal: "metal-mlx-distributed"
nvidia-cuda-12: "cuda12-mlx-distributed"
nvidia-cuda-13: "cuda13-mlx-distributed"
nvidia-l4t: "nvidia-l4t-mlx-distributed"
nvidia-l4t-cuda-12: "nvidia-l4t-mlx-distributed"
nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-distributed"
- &rerankers
name: "rerankers"
alias: "rerankers"
@@ -791,6 +816,11 @@
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-mlx-audio
- !!merge <<: *mlx-distributed
name: "mlx-distributed-development"
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-distributed"
mirrors:
- localai/localai-backends:master-metal-darwin-arm64-mlx-distributed
## mlx
- !!merge <<: *mlx
name: "cpu-mlx"
@@ -944,6 +974,57 @@
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio"
mirrors:
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio
## mlx-distributed
- !!merge <<: *mlx-distributed
name: "cpu-mlx-distributed"
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx-distributed"
mirrors:
- localai/localai-backends:latest-cpu-mlx-distributed
- !!merge <<: *mlx-distributed
name: "cpu-mlx-distributed-development"
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx-distributed"
mirrors:
- localai/localai-backends:master-cpu-mlx-distributed
- !!merge <<: *mlx-distributed
name: "cuda12-mlx-distributed"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed
- !!merge <<: *mlx-distributed
name: "cuda12-mlx-distributed-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx-distributed"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-12-mlx-distributed
- !!merge <<: *mlx-distributed
name: "cuda13-mlx-distributed"
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed"
mirrors:
- localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed
- !!merge <<: *mlx-distributed
name: "cuda13-mlx-distributed-development"
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx-distributed"
mirrors:
- localai/localai-backends:master-gpu-nvidia-cuda-13-mlx-distributed
- !!merge <<: *mlx-distributed
name: "nvidia-l4t-mlx-distributed"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx-distributed"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-mlx-distributed
- !!merge <<: *mlx-distributed
name: "nvidia-l4t-mlx-distributed-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx-distributed"
mirrors:
- localai/localai-backends:master-nvidia-l4t-mlx-distributed
- !!merge <<: *mlx-distributed
name: "cuda13-nvidia-l4t-arm64-mlx-distributed"
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed"
mirrors:
- localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed
- !!merge <<: *mlx-distributed
name: "cuda13-nvidia-l4t-arm64-mlx-distributed-development"
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed"
mirrors:
- localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed
- !!merge <<: *kitten-tts
name: "kitten-tts-development"
uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"


@@ -0,0 +1,23 @@
.PHONY: mlx-distributed
mlx-distributed:
bash install.sh
.PHONY: run
run:
@echo "Running mlx-distributed..."
bash run.sh
@echo "mlx-distributed run."
.PHONY: test
test:
@echo "Testing mlx-distributed..."
bash test.sh
@echo "mlx-distributed tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__


@@ -0,0 +1,509 @@
#!/usr/bin/env python3
"""
MLX Distributed Inference Backend for LocalAI.
Two startup modes:
1. Server mode (started by LocalAI automatically):
run.sh --addr localhost:50051
Distributed config comes from LoadModel options or env vars.
2. Worker mode (started by CLI for remote ranks):
run.sh --worker --hostfile hosts.json --rank 1 --backend ring
Enters a loop waiting for commands from rank 0.
"""
import asyncio
from concurrent import futures
import argparse
import json
import os
import signal
import sys
import tempfile
from typing import List
import grpc
import backend_pb2
import backend_pb2_grpc
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
"""Initialize MLX distributed runtime.
Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank
binds to its own entry (hostfile[rank]) and connects to neighbors for the
ring pipeline.
JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names.
MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that
helps all ranks establish RDMA connections.
"""
import mlx.core as mx
if backend == "ring":
os.environ["MLX_HOSTFILE"] = hostfile
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
return mx.distributed.init(backend="ring", strict=True)
elif backend == "jaccl":
os.environ["MLX_IBV_DEVICES"] = hostfile
os.environ["MLX_RANK"] = str(rank)
if coordinator:
os.environ["MLX_JACCL_COORDINATOR"] = coordinator
return mx.distributed.init(backend="jaccl", strict=True)
else:
raise ValueError(f"Unknown backend: {backend}")
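For the ring backend, the hostfile that `MLX_HOSTFILE` points at is a plain JSON array of `ip:port` strings, one entry per rank. A minimal sketch of producing one (the addresses are illustrative):

```python
import json
import tempfile

# Hypothetical two-rank ring hostfile: one "ip:port" entry per rank;
# rank i binds to hosts[i] and connects to its ring neighbors.
hosts = ["192.168.1.10:5000", "192.168.1.11:5000"]

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(hosts, f)
    hostfile_path = f.name

# Rank 0 would then pass hostfile_path to
# mlx_distributed_init(rank=0, hostfile=hostfile_path, backend="ring"),
# which exports MLX_HOSTFILE and MLX_RANK before calling mx.distributed.init.
```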
def is_float(s):
try:
float(s)
return True
except ValueError:
return False
def is_int(s):
try:
int(s)
return True
except ValueError:
return False
def parse_options(options):
"""Parse key:value option strings into a dict."""
result = {}
for opt in options:
if ":" not in opt:
continue
key, value = opt.split(":", 1)
# Check ints before floats: every integer string also parses as a float
if is_int(value):
value = int(value)
elif is_float(value):
value = float(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"
result[key] = value
return result
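The `key:value` option strings accept floats, integers, and booleans; the value coercion can be sketched standalone (checking integers before floats so integer strings keep their type):

```python
def coerce(value: str):
    # Standalone sketch of the value coercion used for model options;
    # ints are tried before floats so "40" stays an int, not 40.0.
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        pass
    if value.lower() in ("true", "false"):
        return value.lower() == "true"
    return value

raw = ["temp:0.7", "top_k:40", "trust_remote_code:true", "hostfile:/tmp/hosts.json"]
# split(":", 1) keeps paths with colons intact on the value side
parsed = {k: coerce(v) for k, v in (s.split(":", 1) for s in raw)}
# parsed == {'temp': 0.7, 'top_k': 40, 'trust_remote_code': True,
#            'hostfile': '/tmp/hosts.json'}
```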
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""gRPC servicer for distributed MLX inference (runs on rank 0).
When started by LocalAI (server mode), distributed init happens at
LoadModel time using config from model options or environment variables.
"""
def __init__(self):
self.group = None
self.dist_backend = None
self.model = None
self.tokenizer = None
self.coordinator = None
self.options = {}
self.lru_cache = None
self.model_key = None
self.max_kv_size = None
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
async def LoadModel(self, request, context):
try:
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr)
self.options = parse_options(request.Options)
print(f"Options: {self.options}", file=sys.stderr)
# Get distributed config from model options, falling back to env vars.
# If neither is set, run as single-node (no distributed).
hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", ""))
dist_backend = str(self.options.get("distributed_backend",
os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring")))
# JACCL coordinator: rank 0 reads from env (set by CLI --coordinator).
# Not in model options — rank 0 is the coordinator, workers get
# the address via their own --coordinator CLI flag.
jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "")
if hostfile:
from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
from sharding import pipeline_auto_parallel
print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr)
self.dist_backend = dist_backend
self.group = mlx_distributed_init(
rank=0,
hostfile=hostfile,
backend=dist_backend,
coordinator=jaccl_coordinator or None,
)
self.coordinator = DistributedCoordinator(self.group)
self.coordinator.broadcast_command(CMD_LOAD_MODEL)
self.coordinator.broadcast_model_name(request.Model)
else:
print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr)
# Build tokenizer config from request and options
tokenizer_config = {}
if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
tokenizer_config["trust_remote_code"] = True
# Token overrides from options
for key in ["eos_token", "pad_token", "bos_token", "unk_token",
"sep_token", "cls_token", "mask_token"]:
if key in self.options:
tokenizer_config[key] = self.options[key]
if tokenizer_config:
print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
else:
self.model, self.tokenizer = load(request.Model)
if self.group is not None:
from sharding import pipeline_auto_parallel
self.model = pipeline_auto_parallel(self.model, self.group)
print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr)
else:
# Single-node: set up prompt cache for efficient generation
from mlx_cache import ThreadSafeLRUPromptCache
max_cache_entries = self.options.get("max_cache_entries", 10)
self.max_kv_size = self.options.get("max_kv_size", None)
self.model_key = request.Model
self.lru_cache = ThreadSafeLRUPromptCache(
max_size=max_cache_entries,
can_trim_fn=can_trim_prompt_cache,
trim_fn=trim_prompt_cache,
)
print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)
except Exception as err:
print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Error loading model: {err}")
return backend_pb2.Result(message="Model loaded successfully", success=True)
async def Predict(self, request, context):
prompt_cache = None
cache_key = None
try:
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
prompt_text = self._prepare_prompt(request)
tokens = self._get_tokens_from_prompt(prompt_text)
if self.coordinator:
from coordinator import CMD_GENERATE
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
self.coordinator.broadcast_tokens(tokens)
max_tokens, sampler_params = self._build_generation_params(request)
if self.coordinator:
gen_params = self.coordinator.broadcast_generation_params(
max_tokens=max_tokens,
temperature=sampler_params.get('temp', 0.6),
top_p=sampler_params.get('top_p', 1.0),
)
max_tokens = gen_params["max_tokens"]
sampler = make_sampler(**sampler_params)
# Use prompt cache in single-node mode
gen_kwargs = {}
if self.lru_cache is not None:
from mlx_lm.models.cache import make_prompt_cache
cache_key = list(tokens)
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
self.model_key, cache_key
)
if prompt_cache is None:
prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
remaining_tokens = cache_key
gen_kwargs['prompt_cache'] = prompt_cache
tokens = remaining_tokens if remaining_tokens else cache_key
generated = []
for response in stream_generate(
self.model,
self.tokenizer,
prompt=tokens,
max_tokens=max_tokens,
sampler=sampler,
**gen_kwargs,
):
generated.append(response.text)
if cache_key is not None:
cache_key.append(response.token)
if self.lru_cache is not None and cache_key is not None:
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))
except Exception as e:
print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Generation failed: {str(e)}")
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
async def PredictStream(self, request, context):
prompt_cache = None
cache_key = None
try:
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
prompt_text = self._prepare_prompt(request)
tokens = self._get_tokens_from_prompt(prompt_text)
if self.coordinator:
from coordinator import CMD_GENERATE
self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
self.coordinator.broadcast_tokens(tokens)
max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
if self.coordinator:
gen_params = self.coordinator.broadcast_generation_params(
max_tokens=max_tokens,
temperature=sampler_params.get('temp', 0.6),
top_p=sampler_params.get('top_p', 1.0),
)
max_tokens = gen_params["max_tokens"]
sampler = make_sampler(**sampler_params)
# Use prompt cache in single-node mode
gen_kwargs = {}
if self.lru_cache is not None:
from mlx_lm.models.cache import make_prompt_cache
cache_key = list(tokens)
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
self.model_key, cache_key
)
if prompt_cache is None:
prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
remaining_tokens = cache_key
gen_kwargs['prompt_cache'] = prompt_cache
tokens = remaining_tokens if remaining_tokens else cache_key
for response in stream_generate(
self.model,
self.tokenizer,
prompt=tokens,
max_tokens=max_tokens,
sampler=sampler,
**gen_kwargs,
):
if cache_key is not None:
cache_key.append(response.token)
yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
except Exception as e:
print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Streaming failed: {str(e)}")
yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
finally:
if self.lru_cache is not None and prompt_cache is not None and cache_key is not None:
try:
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
except Exception as e:
print(f"Error inserting cache: {e}", file=sys.stderr)
def Embedding(self, request, context):
print("Embeddings not supported in MLX distributed backend", file=sys.stderr)
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Embeddings are not supported in the MLX distributed backend.")
return backend_pb2.EmbeddingResult()
def _prepare_prompt(self, request):
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
return self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return request.Prompt
def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
tokens = self.tokenizer.encode(prompt_text)
if hasattr(tokens, 'tolist'):
return tokens.tolist()
return list(tokens)
def _build_generation_params(self, request, default_max_tokens=200):
import mlx.core as mx
max_tokens = getattr(request, 'Tokens', default_max_tokens)
if max_tokens == 0:
max_tokens = default_max_tokens
temp = getattr(request, 'Temperature', 0.0)
if temp == 0.0:
temp = 0.6
top_p = getattr(request, 'TopP', 0.0)
if top_p == 0.0:
top_p = 1.0
sampler_params = {
'temp': temp,
'top_p': top_p,
'min_p': getattr(request, 'MinP', 0.0),
'top_k': getattr(request, 'TopK', 0),
'xtc_threshold': 0.0,
'xtc_probability': 0.0,
}
seed = getattr(request, 'Seed', 0)
if seed != 0:
mx.random.seed(seed)
if hasattr(self, 'options'):
if 'max_tokens' in self.options:
max_tokens = self.options['max_tokens']
option_mapping = {
'temp': 'temp',
'temperature': 'temp',
'top_p': 'top_p',
'min_p': 'min_p',
'top_k': 'top_k',
'xtc_threshold': 'xtc_threshold',
'xtc_probability': 'xtc_probability',
}
for opt_key, param_key in option_mapping.items():
if opt_key in self.options:
sampler_params[param_key] = self.options[opt_key]
if 'seed' in self.options:
mx.random.seed(self.options['seed'])
# XTC special tokens
xtc_special_tokens = []
if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
xtc_special_tokens = list(self.tokenizer.eos_token_ids)
elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
xtc_special_tokens = [self.tokenizer.eos_token_id]
try:
newline_tokens = self.tokenizer.encode("\n")
xtc_special_tokens.extend(newline_tokens)
except Exception:
pass
sampler_params['xtc_special_tokens'] = xtc_special_tokens
return max_tokens, sampler_params
def run_worker(group):
"""Worker loop for ranks > 0. Waits for commands from rank 0."""
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler
from coordinator import DistributedCoordinator, CMD_LOAD_MODEL, CMD_GENERATE, CMD_SHUTDOWN
from sharding import pipeline_auto_parallel
import mlx.core as mx
coordinator = DistributedCoordinator(group)
model = None
tokenizer = None
print(f"[Rank {group.rank()}] Worker started, waiting for commands...", file=sys.stderr)
while True:
cmd, payload_size = coordinator.wait_for_command()
if cmd == CMD_LOAD_MODEL:
model_name = coordinator.broadcast_model_name()
print(f"[Rank {group.rank()}] Loading model: {model_name}", file=sys.stderr)
model, tokenizer = load(model_name)
model = pipeline_auto_parallel(model, group)
print(f"[Rank {group.rank()}] Model loaded and sharded", file=sys.stderr)
elif cmd == CMD_GENERATE:
if model is None:
print(f"[Rank {group.rank()}] No model loaded, skipping generate", file=sys.stderr)
continue
token_count = coordinator.broadcast_token_count(payload_size)
tokens_array = coordinator.broadcast_tokens([0] * token_count)
tokens = tokens_array.tolist()
gen_params = coordinator.broadcast_generation_params()
sampler = make_sampler(
temp=gen_params["temperature"],
top_p=gen_params["top_p"],
)
for _ in stream_generate(
model, tokenizer,
prompt=tokens,
max_tokens=gen_params["max_tokens"],
sampler=sampler,
):
pass
elif cmd == CMD_SHUTDOWN:
print(f"[Rank {group.rank()}] Shutting down", file=sys.stderr)
break
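The control flow of the worker loop above can be sketched transport-free, with a queue standing in for the `all_sum` broadcasts (the command constants mirror `coordinator.py`; everything else is illustrative):

```python
from queue import Queue

# Same values as coordinator.py: CMD_GENERATE=1, CMD_LOAD_MODEL=2, CMD_SHUTDOWN=-1
CMD_GENERATE, CMD_LOAD_MODEL, CMD_SHUTDOWN = 1, 2, -1

def worker_loop(commands: Queue, log: list) -> None:
    # Dispatch on broadcast commands until CMD_SHUTDOWN arrives;
    # a Queue stands in for DistributedCoordinator.wait_for_command().
    model = None
    while True:
        cmd, payload = commands.get()
        if cmd == CMD_LOAD_MODEL:
            model = payload                 # stands in for load() + sharding
            log.append(("loaded", payload))
        elif cmd == CMD_GENERATE:
            if model is None:               # mirrors the "no model loaded" skip
                log.append(("skipped", payload))
                continue
            log.append(("generated", payload))
        elif cmd == CMD_SHUTDOWN:
            log.append(("shutdown", None))
            break

q, log = Queue(), []
for msg in [(CMD_GENERATE, 3), (CMD_LOAD_MODEL, "m"), (CMD_GENERATE, 5), (CMD_SHUTDOWN, None)]:
    q.put(msg)
worker_loop(q, log)
# log: [("skipped", 3), ("loaded", "m"), ("generated", 5), ("shutdown", None)]
```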
async def serve(address):
server = grpc.aio.server(
migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))
await server.start()
print(f"[Rank 0] gRPC server listening on {address}", file=sys.stderr)
await server.wait_for_termination()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MLX Distributed Backend")
parser.add_argument("--addr", default="localhost:50051",
help="gRPC listen address (used by LocalAI to send requests)")
parser.add_argument("--worker", action="store_true",
help="Run in worker mode (for remote ranks started by CLI)")
parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"],
help="ring = TCP pipeline parallelism, jaccl = RDMA tensor parallelism")
parser.add_argument("--hostfile", default=None,
help="Path to hostfile JSON (required for --worker mode)")
parser.add_argument("--rank", type=int, default=0,
help="Rank of this process (0 = server, >0 = worker)")
parser.add_argument("--coordinator", default=None,
help="JACCL coordinator ip:port (jaccl backend only)")
args = parser.parse_args()
if args.worker:
if not args.hostfile:
print("Error: --hostfile is required in worker mode", file=sys.stderr)
sys.exit(1)
group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator)
run_worker(group)
else:
# Server mode: started by LocalAI with just --addr.
# Distributed init deferred to LoadModel (reads config from model options/env vars).
asyncio.run(serve(args.addr))


@@ -0,0 +1,104 @@
"""
Distributed coordination using MLX distributed primitives.
Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather.
Worker ranks wait in a loop for commands from rank 0.
"""
import json
import struct
import mlx.core as mx
CMD_IDLE = 0
CMD_GENERATE = 1
CMD_LOAD_MODEL = 2
CMD_SHUTDOWN = -1
class DistributedCoordinator:
def __init__(self, group):
self.group = group
self.rank = group.rank()
self.world_size = group.size()
def broadcast_command(self, cmd, payload_size=0):
"""Rank 0 broadcasts a command to all ranks.
Uses all_sum with only rank 0 providing non-zero values so every
rank receives the same command array.
"""
if self.rank == 0:
cmd_array = mx.array([cmd, payload_size], dtype=mx.int32)
else:
cmd_array = mx.zeros((2,), dtype=mx.int32)
result = mx.distributed.all_sum(cmd_array, group=self.group)
mx.eval(result)
return int(result[0].item()), int(result[1].item())
def broadcast_tokens(self, tokens):
"""Broadcast input token ids from rank 0 to all ranks.
Rank 0 provides the real token array; other ranks provide zeros of the
same shape. ``all_sum`` ensures every rank ends up with identical data.
"""
if self.rank == 0:
token_array = mx.array(tokens, dtype=mx.int32)
else:
token_array = mx.zeros((len(tokens),), dtype=mx.int32)
result = mx.distributed.all_sum(token_array, group=self.group)
mx.eval(result)
return result
def broadcast_token_count(self, count):
"""Broadcast the number of tokens so workers can prepare a buffer."""
if self.rank == 0:
count_array = mx.array([count], dtype=mx.int32)
else:
count_array = mx.zeros((1,), dtype=mx.int32)
result = mx.distributed.all_sum(count_array, group=self.group)
mx.eval(result)
return int(result[0].item())
def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0):
"""Broadcast generation parameters from rank 0."""
if self.rank == 0:
params = mx.array([max_tokens, temperature, top_p], dtype=mx.float32)
else:
params = mx.zeros((3,), dtype=mx.float32)
result = mx.distributed.all_sum(params, group=self.group)
mx.eval(result)
return {
"max_tokens": int(result[0].item()),
"temperature": float(result[1].item()),
"top_p": float(result[2].item()),
}
def wait_for_command(self):
"""Worker ranks block here until rank 0 broadcasts a command."""
return self.broadcast_command(CMD_IDLE, 0)
def broadcast_model_name(self, model_name=""):
"""Broadcast model name string from rank 0 to all ranks.
Encodes the model name as int32 codepoints so it can travel via
all_sum.
"""
if self.rank == 0:
encoded = [ord(c) for c in model_name]
# First broadcast the length
length = self.broadcast_token_count(len(encoded))
if length > 0:
name_array = mx.array(encoded, dtype=mx.int32)
result = mx.distributed.all_sum(name_array, group=self.group)
mx.eval(result)
return model_name
return ""
else:
length = self.broadcast_token_count(0)
if length > 0:
name_array = mx.zeros((length,), dtype=mx.int32)
result = mx.distributed.all_sum(name_array, group=self.group)
mx.eval(result)
return "".join(chr(int(c.item())) for c in result)
return ""
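The `all_sum` broadcast trick and the codepoint string encoding can both be illustrated without MLX (a sketch: `all_sum` here is a plain elementwise sum standing in for `mx.distributed.all_sum`, and the model name is made up):

```python
# Every rank contributes an array to an elementwise sum; only rank 0's
# contribution is non-zero, so the sum equals rank 0's data on all ranks.
def all_sum(contributions):
    return [sum(vals) for vals in zip(*contributions)]

model_name = "mlx-community/SomeModel-4bit"      # illustrative
encoded = [ord(c) for c in model_name]           # rank 0's contribution
zeros = [0] * len(encoded)                       # every other rank's

summed = all_sum([encoded, zeros, zeros])        # a 3-rank world
decoded = "".join(chr(c) for c in summed)        # what a worker reconstructs

assert decoded == model_name
```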


@@ -0,0 +1,15 @@
#!/bin/bash
set -e
USE_PIP=true
PYTHON_VERSION=""
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
installRequirements

View File

@@ -0,0 +1,266 @@
"""
Thread-safe LRU prompt cache for MLX-based backends.
Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
with thread-safety additions for LocalAI's gRPC backend.
Usage:
from mlx_cache import ThreadSafeLRUPromptCache
# In LoadModel:
self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)
# In Predict/PredictStream:
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
# ... generate ...
self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
"""
import copy
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple
@dataclass
class CacheEntry:
"""A cache entry with reference counting."""
prompt_cache: List[Any]
count: int
@dataclass
class SearchResult:
"""Result of searching the cache trie."""
model: Any
exact: Optional[List[int]]
shorter: Optional[List[int]]
longer: Optional[List[int]]
common_prefix: int
class ThreadSafeLRUPromptCache:
"""
Thread-safe LRU cache with prefix matching for prompt KV caches.
This cache stores KV caches keyed by token sequences and supports:
- Exact match: Return the cache for the exact token sequence
- Shorter prefix match: Return a cache for a prefix of the tokens
- Longer prefix match: If a longer sequence is cached and can be trimmed
- LRU eviction: When max_size is exceeded, evict least recently used
Thread safety is provided via a threading.Lock that protects all
cache operations.
Args:
max_size: Maximum number of cache entries (default: 10)
can_trim_fn: Optional function to check if a cache can be trimmed
trim_fn: Optional function to trim a cache
"""
def __init__(
self,
max_size: int = 10,
can_trim_fn: Optional[Any] = None,
trim_fn: Optional[Any] = None,
):
self.max_size = max_size
self._cache = {}
self._lru = deque()
self._lock = threading.Lock()
# Optional trim functions (for longer prefix reuse)
self._can_trim_fn = can_trim_fn
self._trim_fn = trim_fn
def _search(self, model, tokens: List[int]) -> SearchResult:
"""
Search the cache for a prompt cache. Return exact or close match.
The cache is organized as a trie where each node is keyed by a token.
This allows efficient prefix matching.
"""
if model not in self._cache:
return SearchResult(model, None, None, None, 0)
current = self._cache[model]
last_cache_index = -1
index = 0
# Traverse the trie following the token sequence
while index < len(tokens) and tokens[index] in current:
current = current[tokens[index]]
if "cache" in current:
last_cache_index = index
index += 1
# Exact match - no need to search for longer or shorter caches
if last_cache_index == len(tokens) - 1:
return SearchResult(model, tuple(tokens), None, None, 0)
# Find the shorter cache (a prefix that has a cache)
# Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
# Single-token prefixes are not matched, which allows longer cached
# sequences to be preferred for trimming. This is acceptable because
# real prompts with chat templates are always many tokens.
shorter = None
if last_cache_index > 0:
shorter = tuple(tokens[: last_cache_index + 1])
# Check for caches that are longer than our token sequence
longer = None
common_prefix = index
if index > 0 and last_cache_index <= 0:
best = None
stack = [(current, [])]
while stack:
current, extra = stack.pop()
if "cache" in current:
if best is None or len(extra) < len(best):
best = extra
else:
for tok in current:
stack.append((current[tok], extra + [tok]))
if best is not None:
longer = tuple(tokens[:index] + best)
return SearchResult(model, None, shorter, longer, common_prefix)
def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
"""Get a cache entry by traversing the trie."""
current = self._cache[model]
for tok in tokens:
current = current[tok]
return current["cache"]
def _delete(self, model, tokens: Tuple[int, ...]) -> None:
"""Delete a cache entry and clean up empty trie nodes."""
path = [self._cache[model]]
for tok in tokens:
path.append(path[-1][tok])
del path[-1]["cache"]
# Clean up empty nodes bottom-up
for i in reversed(range(len(tokens))):
d_prev, d, t = path[i], path[i + 1], tokens[i]
if len(d) > 0:
break
del d_prev[t]
def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
"""
Extract a cache entry for exclusive use.
If the entry has count > 1, deep copy and decrement.
If count == 1, remove from cache entirely.
"""
cache_entry = self._get(model, tokens)
if cache_entry.count == 1:
self._delete(model, tokens)
self._lru.remove((model, tokens))
return cache_entry
cache_entry.count -= 1
return CacheEntry(
copy.deepcopy(cache_entry.prompt_cache),
1,
)
def fetch_nearest_cache(
self, model, tokens: List[int]
) -> Tuple[Optional[List[Any]], List[int]]:
"""
Fetch the nearest cache for the given token sequence.
Thread-safe. Returns (cache, remaining_tokens) where:
- cache: The KV cache to use (or None if no cache found)
- remaining_tokens: Tokens that still need to be processed
Args:
model: Model identifier (used to namespace caches)
tokens: The full token sequence for the prompt
Returns:
Tuple of (prompt_cache, remaining_tokens)
"""
with self._lock:
tokens_tuple = tuple(tokens)
result = self._search(model, tokens)
# Exact match - extract and return
if result.exact is not None:
cache_entry = self._extract(result.model, result.exact)
return cache_entry.prompt_cache, []
# Shorter prefix match - extract and return remaining
if result.shorter is not None:
cache_entry = self._extract(result.model, result.shorter)
prefix_len = len(result.shorter)
return cache_entry.prompt_cache, list(tokens[prefix_len:])
# Longer prefix match - try to trim if possible
if result.longer is not None and self._can_trim_fn is not None:
cache_entry = self._get(result.model, result.longer)
if self._can_trim_fn(cache_entry.prompt_cache):
# Deep copy and trim
trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
prefix = min(len(tokens) - 1, result.common_prefix)
num_to_trim = len(result.longer) - prefix
if self._trim_fn is not None:
self._trim_fn(trimmed_cache, num_to_trim)
return trimmed_cache, list(tokens[prefix:])
# No match found
return None, list(tokens)
def insert_cache(
self, model, tokens: List[int], prompt_cache: List[Any]
) -> None:
"""
Insert a cache entry after generation completes.
Thread-safe. Handles LRU eviction if max_size is exceeded.
Args:
model: Model identifier (used to namespace caches)
tokens: The full token sequence (prompt + generated)
prompt_cache: The KV cache to store
"""
with self._lock:
tokens_tuple = tuple(tokens)
if model not in self._cache:
self._cache[model] = {}
current = self._cache[model]
# Build trie path
for tok in tokens_tuple:
if tok not in current:
current[tok] = {}
current = current[tok]
# Update or create entry
if "cache" in current:
current["cache"].count += 1
self._lru.remove((model, tokens_tuple))
else:
current["cache"] = CacheEntry(prompt_cache, 1)
# Update LRU order
self._lru.append((model, tokens_tuple))
# Evict if over capacity
if len(self._lru) > self.max_size:
evict_model, evict_tokens = self._lru.popleft()
self._delete(evict_model, evict_tokens)
def clear(self) -> None:
"""Clear all cache entries. Thread-safe."""
with self._lock:
self._cache.clear()
self._lru.clear()
def __len__(self) -> int:
"""Return the number of cache entries. Thread-safe."""
with self._lock:
return len(self._lru)
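The trie walk that `_search` performs can be sketched in isolation: follow the token sequence down the trie and remember the deepest node that holds a `"cache"` marker; the path to that node is the longest cached prefix. A minimal standalone sketch (`insert` and `longest_cached_prefix` are hypothetical helper names for illustration only):

```python
def insert(trie, tokens):
    # Mark the node reached by following `tokens` as holding a cache.
    node = trie
    for tok in tokens:
        node = node.setdefault(tok, {})
    node["cache"] = True

def longest_cached_prefix(trie, tokens):
    # Walk the trie along `tokens`, remembering the deepest cached node.
    node, best = trie, -1
    for i, tok in enumerate(tokens):
        if tok not in node:
            break
        node = node[tok]
        if "cache" in node:
            best = i
    return tokens[: best + 1]

trie = {}
insert(trie, [1, 2, 3])
insert(trie, [1, 2, 3, 4, 5])
assert longest_cached_prefix(trie, [1, 2, 3, 4, 5]) == [1, 2, 3, 4, 5]  # exact
assert longest_cached_prefix(trie, [1, 2, 3, 4, 9]) == [1, 2, 3]        # shorter prefix
```

The real cache layers reference counting, LRU eviction, and the optional longer-prefix trimming on top of this lookup.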

View File

@@ -0,0 +1,2 @@
mlx-lm
mlx[cpu]

View File

@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda12]

View File

@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda13]

View File

@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda12]

View File

@@ -0,0 +1,2 @@
mlx-lm
mlx[cuda13]

View File

@@ -0,0 +1 @@
mlx-lm

View File

@@ -0,0 +1,4 @@
grpcio==1.71.0
protobuf
certifi
setuptools

View File

@@ -0,0 +1,11 @@
#!/bin/bash
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
startBackend "$@"

View File

@@ -0,0 +1,136 @@
"""
Auto-parallelism for MLX distributed inference.
Provides pipeline parallelism (Ring backend) by wrapping model layers with
distributed send/recv operations. Ported from exo's auto_parallel.py with
simplifications for LocalAI's use case.
"""
import mlx.core as mx
import mlx.nn as nn
class PipelineFirstLayer(nn.Module):
"""Wraps the first layer on each rank to receive from the previous rank."""
def __init__(self, original_layer, rank, group):
super().__init__()
dict.__setitem__(self, "_original_layer", original_layer)
self.rank = rank
self.group = group
@property
def original_layer(self):
return self["_original_layer"]
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self["_original_layer"], name)
def __call__(self, x, *args, **kwargs):
if self.rank != 0:
mx.eval(x)
x = mx.distributed.recv_like(x, self.rank - 1, group=self.group)
mx.eval(x)
return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(nn.Module):
"""Wraps the last layer on each rank to send to the next rank."""
def __init__(self, original_layer, rank, world_size, group):
super().__init__()
dict.__setitem__(self, "_original_layer", original_layer)
self.rank = rank
self.world_size = world_size
self.group = group
@property
def original_layer(self):
return self["_original_layer"]
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self["_original_layer"], name)
def __call__(self, x, *args, **kwargs):
output = self.original_layer(x, *args, **kwargs)
mx.eval(output)
if self.rank != self.world_size - 1:
output = mx.distributed.send(
output, (self.rank + 1) % self.world_size, group=self.group
)
mx.eval(output)
# Gather output from all ranks so every rank has the final result
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
]
mx.eval(output)
return output
def get_inner_model(model):
"""Get the inner model (model.model or model.transformer)."""
for attr in ("model", "transformer"):
inner = getattr(model, attr, None)
if isinstance(inner, nn.Module):
# Some models have model.model (e.g. language_model.model)
inner_inner = getattr(inner, "model", None)
if isinstance(inner_inner, nn.Module):
return inner_inner
return inner
raise ValueError("Model must have a 'model' or 'transformer' attribute")
def get_layers(inner_model):
"""Get the list of transformer layers."""
for attr in ("layers", "h"):
layers = getattr(inner_model, attr, None)
if layers is not None:
return layers
raise ValueError("Model must have a 'layers' or 'h' attribute")
def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None):
"""Apply pipeline parallelism to a model.
Each rank only keeps its slice of layers. The first layer receives from
the previous rank, and the last layer sends to the next rank.
Args:
model: The MLX model (must have model.layers or similar)
group: The distributed group
start_layer: First layer index for this rank (auto-computed if None)
end_layer: Last layer index (exclusive) for this rank (auto-computed if None)
"""
rank = group.rank()
world_size = group.size()
inner = get_inner_model(model)
layers = list(get_layers(inner))
total_layers = len(layers)
if start_layer is None or end_layer is None:
layers_per_rank = total_layers // world_size
remainder = total_layers % world_size
start_layer = rank * layers_per_rank + min(rank, remainder)
end_layer = start_layer + layers_per_rank + (1 if rank < remainder else 0)
layers = layers[start_layer:end_layer]
for layer in layers:
mx.eval(layer)
# Wrap first and last layers
layers[0] = PipelineFirstLayer(layers[0], rank, group=group)
layers[-1] = PipelineLastLayer(layers[-1], rank, world_size, group=group)
# Replace layers on the inner model
if hasattr(inner, "layers"):
inner.layers = layers
elif hasattr(inner, "h"):
inner.h = layers
return model
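The layer split computed in `pipeline_auto_parallel` gives every rank `total // world_size` layers, with the first `total % world_size` ranks taking one extra layer each, so the slices are contiguous and cover the whole stack. The same arithmetic as a standalone sketch (`layer_range` is an illustrative helper, not part of the backend):

```python
def layer_range(rank, world_size, total_layers):
    # Mirrors the start/end computation in pipeline_auto_parallel.
    per_rank = total_layers // world_size
    remainder = total_layers % world_size
    start = rank * per_rank + min(rank, remainder)
    end = start + per_rank + (1 if rank < remainder else 0)
    return start, end

# 26 layers over 4 ranks: ranks 0-1 get 7 layers, ranks 2-3 get 6.
ranges = [layer_range(r, 4, 26) for r in range(4)]
assert ranges == [(0, 7), (7, 14), (14, 20), (20, 26)]
```

Each rank then keeps only its slice, wrapping the first kept layer in `PipelineFirstLayer` (recv) and the last in `PipelineLastLayer` (send + all_gather).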

View File

@@ -0,0 +1,87 @@
import unittest
import subprocess
import time
import grpc
import backend_pb2
import backend_pb2_grpc
class TestBackendServicer(unittest.TestCase):
def setUp(self):
self.service = subprocess.Popen(
["python", "backend.py", "--addr", "localhost:50051"]
)
time.sleep(10)
def tearDown(self) -> None:
self.service.terminate()
self.service.wait()
def test_server_startup(self):
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()
def test_load_model(self):
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
self.assertTrue(response.success)
self.assertEqual(response.message, "Model loaded successfully")
except Exception as err:
print(err)
self.fail("LoadModel service failed")
finally:
self.tearDown()
def test_text(self):
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
self.assertTrue(response.success)
req = backend_pb2.PredictOptions(Prompt="The capital of France is")
resp = stub.Predict(req)
self.assertIsNotNone(resp.message)
except Exception as err:
print(err)
self.fail("text service failed")
finally:
self.tearDown()
def test_sampling_params(self):
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
self.assertTrue(response.success)
req = backend_pb2.PredictOptions(
Prompt="The capital of France is",
TopP=0.8,
Tokens=50,
Temperature=0.7,
TopK=40,
MinP=0.05,
Seed=42,
)
resp = stub.Predict(req)
self.assertIsNotNone(resp.message)
except Exception as err:
print(err)
self.fail("sampling params service failed")
finally:
self.tearDown()

View File

@@ -0,0 +1,12 @@
#!/bin/bash
set -e
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi
runUnittests

View File

@@ -84,19 +84,37 @@ func (a *Application) StartP2P() error {
n = node
}
// Attach a ServiceDiscoverer to the p2p node
// Attach a ServiceDiscoverer to the p2p node for llama.cpp workers
xlog.Info("Starting P2P server discovery...")
if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node schema.NodeData) {
if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID), func(serviceID string, node schema.NodeData) {
var tunnelAddresses []string
for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) {
for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID)) {
if v.IsOnline() {
tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
} else {
xlog.Info("Node is offline", "node", v.ID)
}
}
if a.applicationConfig.TunnelCallback != nil {
a.applicationConfig.TunnelCallback(tunnelAddresses)
if a.applicationConfig.LlamaCPPTunnelCallback != nil {
a.applicationConfig.LlamaCPPTunnelCallback(tunnelAddresses)
}
}, true); err != nil {
return err
}
// Attach a ServiceDiscoverer for MLX distributed workers
xlog.Info("Starting MLX P2P worker discovery...")
if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.MLXWorkerID), func(serviceID string, node schema.NodeData) {
var tunnelAddresses []string
for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.MLXWorkerID)) {
if v.IsOnline() {
tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
} else {
xlog.Info("MLX node is offline", "node", v.ID)
}
}
if a.applicationConfig.MLXTunnelCallback != nil {
a.applicationConfig.MLXTunnelCallback(tunnelAddresses)
}
}, true); err != nil {
return err

View File

@@ -2,8 +2,10 @@ package cli
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
@@ -171,12 +173,18 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithMachineTag(r.MachineTag),
config.WithAPIAddress(r.Address),
config.WithAgentJobRetentionDays(r.AgentJobRetentionDays),
config.WithTunnelCallback(func(tunnels []string) {
config.WithLlamaCPPTunnelCallback(func(tunnels []string) {
tunnelEnvVar := strings.Join(tunnels, ",")
// TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable
os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar)
}),
config.WithMLXTunnelCallback(func(tunnels []string) {
hostfile := filepath.Join(os.TempDir(), "localai_mlx_hostfile.json")
data, _ := json.Marshal(tunnels)
os.WriteFile(hostfile, data, 0644)
os.Setenv("MLX_DISTRIBUTED_HOSTFILE", hostfile)
xlog.Debug("setting MLX_DISTRIBUTED_HOSTFILE", "value", hostfile, "tunnels", tunnels)
}),
}
if r.DisableMetricsEndpoint {

View File

@@ -8,6 +8,8 @@ type WorkerFlags struct {
}
type Worker struct {
P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"`
P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"`
LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"`
MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"`
}

View File

@@ -0,0 +1,65 @@
package worker
import (
"context"
"encoding/json"
"errors"
"os/exec"
"path/filepath"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/xlog"
)
const mlxDistributedGalleryName = "mlx-distributed"
// findMLXDistributedBackendPath finds or installs the mlx-distributed backend
// and returns the directory containing run.sh.
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
return "", err
}
backend, ok := backends.Get(mlxDistributedGalleryName)
if !ok {
ml := model.NewModelLoader(systemState)
var gals []config.Gallery
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
xlog.Error("failed loading galleries", "error", err)
return "", err
}
if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true); err != nil {
xlog.Error("mlx-distributed backend not found, failed to install it", "error", err)
return "", err
}
// Re-fetch after install
backends, err = gallery.ListSystemBackends(systemState)
if err != nil {
return "", err
}
backend, ok = backends.Get(mlxDistributedGalleryName)
if !ok {
return "", errors.New("mlx-distributed backend not found after install")
}
}
backendPath := filepath.Dir(backend.RunFile)
if backendPath == "" {
return "", errors.New("mlx-distributed backend not found, install it first")
}
return backendPath, nil
}
// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend.
// backendPath is the directory containing run.sh (empty string to fall back to
// running backend.py directly via python3).
func buildMLXCommand(backendPath string, args ...string) *exec.Cmd {
if backendPath != "" {
return exec.Command(filepath.Join(backendPath, "run.sh"), args...)
}
return exec.Command("python3", append([]string{"backend.py"}, args...)...)
}

View File

@@ -0,0 +1,62 @@
package worker
import (
"fmt"
"os"
"syscall"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/xlog"
)
type MLXDistributed struct {
WorkerFlags `embed:""`
Hostfile string `env:"MLX_DISTRIBUTED_HOSTFILE" required:"" help:"Path to hostfile JSON. Ring: array of 'ip:port' where entry i is rank i's listen address. JACCL: 2D matrix of RDMA device names."`
Rank int `env:"MLX_RANK" required:"" help:"Rank of this process (0 = gRPC server + ring participant, >0 = worker only)"`
Backend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend: 'ring' (TCP pipeline parallelism) or 'jaccl' (RDMA tensor parallelism)"`
Addr string `env:"MLX_DISTRIBUTED_ADDR" default:"localhost:50051" help:"gRPC API listen address for LocalAI (rank 0 only, separate from ring communication)"`
Coordinator string `env:"MLX_JACCL_COORDINATOR" default:"" help:"JACCL coordinator ip:port — rank 0's address where it accepts RDMA setup connections (all ranks must use the same value)"`
}
func (r *MLXDistributed) Run(ctx *cliContext.Context) error {
systemState, err := system.GetSystemState(
system.WithBackendPath(r.BackendsPath),
system.WithBackendSystemPath(r.BackendsSystemPath),
)
if err != nil {
return err
}
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
if err != nil {
return fmt.Errorf("cannot find mlx-distributed backend: %w", err)
}
args := []string{
"--backend", r.Backend,
"--hostfile", r.Hostfile,
"--rank", fmt.Sprint(r.Rank),
}
if r.Rank == 0 {
args = append(args, "--addr", r.Addr)
} else {
args = append(args, "--worker")
}
if r.Backend == "jaccl" && r.Coordinator != "" {
args = append(args, "--coordinator", r.Coordinator)
}
cmd := buildMLXCommand(backendPath, args...)
runSh := cmd.Path
xlog.Info("Starting mlx-distributed", "rank", r.Rank, "backend", r.Backend, "hostfile", r.Hostfile)
return syscall.Exec(
runSh,
append([]string{runSh}, args...),
os.Environ(),
)
}
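For the ring backend, the hostfile consumed via `--hostfile` (and written by the `MLXTunnelCallback` in the run command) is a JSON array of `ip:port` strings where entry `i` is rank `i`'s listen address, as the flag's help text describes. A sketch of producing and consuming one, mirroring what the Go callback does (the addresses and two-rank setup are illustrative):

```python
import json
import os
import tempfile

# Entry i is rank i's listen address, per the --hostfile help text.
tunnels = ["127.0.0.1:34567", "127.0.0.1:34568"]

hostfile = os.path.join(tempfile.gettempdir(), "localai_mlx_hostfile.json")
with open(hostfile, "w") as f:
    json.dump(tunnels, f)

# Workers locate the same file through MLX_DISTRIBUTED_HOSTFILE.
with open(hostfile) as f:
    assert json.load(f) == tunnels
```

For the jaccl backend the file instead holds a 2D matrix of RDMA device names, and all ranks must additionally agree on the `--coordinator` address.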

View File

@@ -62,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
p = r.RunnerPort
}
_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID))
if err != nil {
return err
}
@@ -104,7 +104,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
}
}()
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID))
if err != nil {
return err
}

View File

@@ -0,0 +1,97 @@
package worker
import (
"context"
"fmt"
"os"
"time"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/pkg/signals"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/xlog"
"github.com/phayes/freeport"
)
type P2PMLX struct {
WorkerFlags `embed:""`
Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"`
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode" group:"p2p"`
MLXListenPort string `env:"MLX_LISTEN_PORT" default:"5555" help:"Port for MLX distributed communication"`
MLXBackend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend (ring or jaccl)"`
}
func (r *P2PMLX) Run(ctx *cliContext.Context) error {
if r.Token == "" {
return fmt.Errorf("token is required")
}
systemState, err := system.GetSystemState(
system.WithBackendPath(r.BackendsPath),
system.WithBackendSystemPath(r.BackendsSystemPath),
)
if err != nil {
return err
}
port, err := freeport.GetFreePort()
if err != nil {
return err
}
if r.MLXListenPort != "" {
fmt.Sscanf(r.MLXListenPort, "%d", &port)
}
address := "127.0.0.1"
c, cancel := context.WithCancel(context.Background())
defer cancel()
backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState)
if err != nil {
xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err)
}
go func() {
for {
hostfile := os.Getenv("MLX_DISTRIBUTED_HOSTFILE")
if hostfile == "" {
xlog.Info("Waiting for MLX_DISTRIBUTED_HOSTFILE to be set by P2P discovery...")
time.Sleep(2 * time.Second)
continue
}
xlog.Info("Starting mlx-distributed worker", "address", address, "port", port, "hostfile", hostfile)
cmd := buildMLXCommand(backendPath,
"--worker",
"--backend", r.MLXBackend,
"--hostfile", hostfile,
"--rank", "0",
)
cmd.Env = os.Environ()
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
if err := cmd.Run(); err != nil {
xlog.Error("mlx-distributed worker exited", "error", err)
}
time.Sleep(2 * time.Second)
}
}()
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID))
if err != nil {
return err
}
xlog.Info("MLX distributed worker registered on P2P network", "address", address, "port", port)
signals.RegisterGracefulTerminationHandler(func() {
cancel()
})
<-c.Done()
return nil
}

View File

@@ -82,7 +82,8 @@ type ApplicationConfig struct {
APIAddress string
TunnelCallback func(tunnels []string)
LlamaCPPTunnelCallback func(tunnels []string)
MLXTunnelCallback func(tunnels []string)
DisableRuntimeSettings bool
@@ -457,9 +458,15 @@ func WithContextSize(ctxSize int) AppOption {
}
}
func WithTunnelCallback(callback func(tunnels []string)) AppOption {
func WithLlamaCPPTunnelCallback(callback func(tunnels []string)) AppOption {
return func(o *ApplicationConfig) {
o.TunnelCallback = callback
o.LlamaCPPTunnelCallback = callback
}
}
func WithMLXTunnelCallback(callback func(tunnels []string)) AppOption {
return func(o *ApplicationConfig) {
o.MLXTunnelCallback = callback
}
}

View File

@@ -156,7 +156,7 @@ func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockch
for d := range data {
toScanForWorkers := false
cd := ClusterData{}
isWorkerCluster := d == p2p.WorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.WorkerID))
isWorkerCluster := d == p2p.LlamaCPPWorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.LlamaCPPWorkerID))
isFederatedCluster := d == p2p.FederatedID || (strings.Contains(d, "_") && strings.Contains(d, p2p.FederatedID))
switch {
case isWorkerCluster:

View File

@@ -15,8 +15,9 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
// Render index
return func(c echo.Context) error {
return c.JSON(200, schema.P2PNodesResponse{
Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
LlamaCPPNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)),
FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
MLXNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)),
})
}
}

View File

@@ -103,8 +103,9 @@ function StepNumber({ n, bg, color }) {
export default function P2P() {
const { addToast } = useOutletContext()
const [workers, setWorkers] = useState([])
const [mlxWorkers, setMlxWorkers] = useState([])
const [federation, setFederation] = useState([])
const [stats, setStats] = useState({ workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 } })
const [stats, setStats] = useState({ llama_cpp_workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 }, mlx_workers: { online: 0, total: 0 } })
const [loading, setLoading] = useState(true)
const [enabled, setEnabled] = useState(false)
const [token, setToken] = useState('')
@@ -129,7 +130,13 @@ export default function P2P() {
if (p2pToken) {
if (wRes.status === 'fulfilled') {
const data = wRes.value
setWorkers(data?.nodes || (Array.isArray(data) ? data : []))
// Handle both old format ({nodes: [...]}) and new grouped format ({llama_cpp: {nodes: [...]}, mlx: {nodes: [...]}})
if (data?.llama_cpp) {
setWorkers(data.llama_cpp.nodes || [])
setMlxWorkers(data.mlx?.nodes || [])
} else {
setWorkers(data?.nodes || (Array.isArray(data) ? data : []))
}
}
if (fRes.status === 'fulfilled') {
const data = fRes.value
@@ -274,8 +281,10 @@ export default function P2P() {
// ── P2P Enabled ──
const fedOnline = stats.federated?.online ?? 0
const fedTotal = stats.federated?.total ?? 0
const wrkOnline = stats.workers?.online ?? 0
const wrkTotal = stats.workers?.total ?? 0
const llamaOnline = stats.llama_cpp_workers?.online ?? 0
const llamaTotal = stats.llama_cpp_workers?.total ?? 0
const mlxOnline = stats.mlx_workers?.online ?? 0
const mlxTotal = stats.mlx_workers?.total ?? 0
return (
<div className="page">
@@ -401,7 +410,7 @@ export default function P2P() {
fontSize: '0.75rem',
color: activeTab === 'sharding' ? 'var(--color-accent)' : 'var(--color-text-muted)',
}}>
{wrkOnline}/{wrkTotal} workers
{llamaOnline + mlxOnline}/{llamaTotal + mlxTotal} workers
</div>
</div>
</div>
@@ -562,6 +571,21 @@ export default function P2P() {
borderRadius: 'var(--radius-lg)', overflow: 'hidden',
}}>
<div style={{ padding: 'var(--spacing-lg)', borderBottom: '1px solid var(--color-border-subtle)' }}>
<div style={{
background: 'var(--color-accent-light)', border: '1px solid rgba(139,92,246,0.3)',
borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm) var(--spacing-md)',
fontSize: '0.8125rem', color: 'var(--color-text-secondary)', marginBottom: 'var(--spacing-md)',
}}>
<i className="fas fa-info-circle" style={{ color: 'var(--color-accent)', marginRight: 6 }} />
<strong>Different from federation:</strong> Federation distributes whole requests across instances. Model sharding splits a single model across machines for joint inference.
</div>
{/* ── llama.cpp RPC Workers Section ── */}
<h3 style={{ fontSize: '1.125rem', fontWeight: 700, marginBottom: 'var(--spacing-sm)', display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)' }}>
<i className="fas fa-microchip" style={{ color: 'var(--color-accent)' }} />
llama.cpp RPC Workers
</h3>
{/* Architecture diagram */}
<div style={{
background: 'var(--color-bg-primary)', border: '1px solid var(--color-border-subtle)',
@@ -604,25 +628,15 @@ export default function P2P() {
</div>
<p style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)', textAlign: 'center', marginTop: 'var(--spacing-sm)', lineHeight: 1.5 }}>
Model weights are <strong>split across RPC workers</strong>. Each worker holds a portion of the model layers in its memory (GPU or CPU).
The LocalAI instance orchestrates inference by communicating with all workers via RPC.
</p>
</div>
<div style={{
background: 'var(--color-accent-light)', border: '1px solid rgba(139,92,246,0.3)',
borderRadius: 'var(--radius-md)', padding: 'var(--spacing-sm) var(--spacing-md)',
fontSize: '0.8125rem', color: 'var(--color-text-secondary)', marginBottom: 'var(--spacing-md)',
}}>
<i className="fas fa-info-circle" style={{ color: 'var(--color-accent)', marginRight: 6 }} />
<strong>Different from federation:</strong> Federation distributes whole requests across instances. Model sharding splits a single model's weights across machines for joint inference. Currently only supported with <strong>llama.cpp</strong> based models.
</div>
{/* Status + nodes */}
{/* llama.cpp Status + nodes */}
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-md)' }}>
<h3 style={{ fontSize: '1rem', fontWeight: 700 }}>Connected Workers</h3>
<h4 style={{ fontSize: '1rem', fontWeight: 600 }}>Connected Workers</h4>
<div style={{ fontSize: '1.25rem', fontWeight: 700 }}>
<span style={{ color: wrkOnline > 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{wrkOnline}</span>
<span style={{ color: 'var(--color-text-secondary)', fontSize: '1rem' }}>/{wrkTotal}</span>
<span style={{ color: llamaOnline > 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{llamaOnline}</span>
<span style={{ color: 'var(--color-text-secondary)', fontSize: '1rem' }}>/{llamaTotal}</span>
</div>
</div>
@@ -633,7 +647,7 @@ export default function P2P() {
borderRadius: 'var(--radius-lg)',
}}>
<i className="fas fa-puzzle-piece" style={{ fontSize: '2rem', color: 'var(--color-text-muted)', marginBottom: 'var(--spacing-sm)' }} />
<p style={{ fontWeight: 500, color: 'var(--color-text-secondary)' }}>No llama.cpp workers connected</p>
<p style={{ fontSize: '0.8125rem', color: 'var(--color-text-muted)', marginTop: 'var(--spacing-xs)' }}>Start workers to see them here</p>
</div>
) : (
@@ -645,17 +659,99 @@ export default function P2P() {
)}
</div>
{/* ── MLX Distributed Workers Section ── */}
<div style={{ padding: 'var(--spacing-lg)', borderBottom: '1px solid var(--color-border-subtle)' }}>
<h3 style={{ fontSize: '1.125rem', fontWeight: 700, marginBottom: 'var(--spacing-sm)', display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)' }}>
<i className="fas fa-apple-whole" style={{ color: 'rgba(245,158,11,1)' }} />
MLX Distributed Workers
</h3>
{/* MLX Architecture diagram */}
<div style={{
background: 'var(--color-bg-primary)', border: '1px solid var(--color-border-subtle)',
borderRadius: 'var(--radius-lg)', padding: 'var(--spacing-md)', marginBottom: 'var(--spacing-md)',
}}>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', gap: 'var(--spacing-md)', flexWrap: 'wrap' }}>
<div style={{ textAlign: 'center' }}>
<div style={{
width: 48, height: 48, borderRadius: 'var(--radius-md)',
background: 'var(--color-primary-light)', display: 'flex', alignItems: 'center', justifyContent: 'center',
margin: '0 auto var(--spacing-xs)', border: '2px solid var(--color-primary)',
}}>
<i className="fas fa-server" style={{ color: 'var(--color-primary)', fontSize: '1rem' }} />
</div>
<div style={{ fontSize: '0.75rem', fontWeight: 600 }}>LocalAI</div>
<div style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>Rank 0</div>
</div>
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: '2px' }}>
<i className="fas fa-arrows-left-right" style={{ color: 'var(--color-text-muted)', fontSize: '0.875rem' }} />
<span style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>Ring / JACCL</span>
</div>
<div style={{ textAlign: 'center' }}>
<div style={{ display: 'flex', gap: '4px', marginBottom: 'var(--spacing-xs)' }}>
{['Layers 1-16', 'Layers 17-32'].map((label, i) => (
<div key={i} style={{ textAlign: 'center' }}>
<div style={{
width: 64, height: 36, borderRadius: 'var(--radius-sm)',
background: 'rgba(245,158,11,0.1)', display: 'flex', alignItems: 'center', justifyContent: 'center',
border: '1px solid rgba(245,158,11,0.3)',
}}>
<i className="fas fa-microchip" style={{ color: 'rgba(245,158,11,1)', fontSize: '0.75rem' }} />
</div>
<div style={{ fontSize: '0.5625rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{label}</div>
</div>
))}
</div>
<div style={{ fontSize: '0.75rem', fontWeight: 600 }}>MLX Workers</div>
<div style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>Pipeline parallel</div>
</div>
</div>
<p style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)', textAlign: 'center', marginTop: 'var(--spacing-sm)', lineHeight: 1.5 }}>
MLX distributed uses <strong>native Apple Silicon communication</strong>. All nodes execute model code simultaneously via pipeline or tensor parallelism.
</p>
</div>
{/* MLX Status + nodes */}
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-md)' }}>
<h4 style={{ fontSize: '1rem', fontWeight: 600 }}>Connected MLX Workers</h4>
<div style={{ fontSize: '1.25rem', fontWeight: 700 }}>
<span style={{ color: mlxOnline > 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{mlxOnline}</span>
<span style={{ color: 'var(--color-text-secondary)', fontSize: '1rem' }}>/{mlxTotal}</span>
</div>
</div>
{mlxWorkers.length === 0 ? (
<div style={{
textAlign: 'center', padding: 'var(--spacing-lg)',
background: 'var(--color-bg-primary)', border: '1px solid var(--color-border-subtle)',
borderRadius: 'var(--radius-lg)',
}}>
<i className="fas fa-apple-whole" style={{ fontSize: '2rem', color: 'var(--color-text-muted)', marginBottom: 'var(--spacing-sm)' }} />
<p style={{ fontWeight: 500, color: 'var(--color-text-secondary)' }}>No MLX workers connected</p>
<p style={{ fontSize: '0.8125rem', color: 'var(--color-text-muted)', marginTop: 'var(--spacing-xs)' }}>Start MLX workers on Apple Silicon Macs</p>
</div>
) : (
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(auto-fill, minmax(260px, 1fr))', gap: 'var(--spacing-md)' }}>
{mlxWorkers.map((node, i) => (
<NodeCard key={node.id || i} node={node} label={`MLX Rank ${i + 1}`} iconColor="rgba(245,158,11,1)" iconBg="rgba(245,158,11,0.1)" />
))}
</div>
)}
</div>
{/* Setup Guides */}
<div style={{ padding: 'var(--spacing-lg)' }}>
<h3 style={{ fontSize: '1.125rem', fontWeight: 700, marginBottom: 'var(--spacing-md)' }}>
<i className="fas fa-book" style={{ color: 'var(--color-accent)', marginRight: 'var(--spacing-sm)' }} />
Setup Workers
</h3>
<div style={{
background: 'var(--color-bg-primary)', borderRadius: 'var(--radius-lg)',
border: '1px solid var(--color-border-subtle)', padding: 'var(--spacing-lg)',
marginBottom: 'var(--spacing-md)',
}}>
<h4 style={{ fontSize: '1rem', fontWeight: 600, marginBottom: 'var(--spacing-sm)' }}>llama.cpp RPC Worker</h4>
<p style={{ color: 'var(--color-text-secondary)', fontSize: '0.875rem', marginBottom: 'var(--spacing-sm)' }}>
Each worker exposes its GPU/CPU memory as a shard for distributed model inference.
</p>
@@ -663,17 +759,24 @@ export default function P2P() {
command={`docker run -ti --net host \\\n -e TOKEN="${token}" \\\n --name local-ai-worker \\\n localai/localai:latest-cpu worker p2p-llama-cpp-rpc`}
addToast={addToast}
/>
<p style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem', marginTop: 'var(--spacing-sm)' }}>
Run this on each machine you want to contribute as a shard. The worker will automatically join the network and advertise its resources.
</p>
</div>
<div style={{
background: 'var(--color-bg-primary)', borderRadius: 'var(--radius-lg)',
border: '1px solid rgba(245,158,11,0.3)', padding: 'var(--spacing-lg)',
}}>
<h4 style={{ fontSize: '1rem', fontWeight: 600, marginBottom: 'var(--spacing-sm)' }}>MLX Distributed Worker</h4>
<p style={{ color: 'var(--color-text-secondary)', fontSize: '0.875rem', marginBottom: 'var(--spacing-sm)' }}>
Run on Apple Silicon Macs to participate in distributed MLX inference via pipeline parallelism.
</p>
<CommandBlock
command={`docker run -ti --net host \\\n -e TOKEN="${token}" \\\n --name local-ai-mlx-worker \\\n localai/localai:latest-metal-darwin-arm64 worker p2p-mlx`}
addToast={addToast}
/>
<p style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem', marginTop: 'var(--spacing-sm)' }}>
For more information, see the{' '}
<a href="https://localai.io/features/mlx-distributed/" target="_blank" rel="noopener noreferrer"
style={{ color: 'rgba(245,158,11,1)' }}>MLX Distributed</a> docs.
</p>
</div>
</div>

View File

@@ -1041,11 +1041,12 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
// P2P APIs
app.GET("/api/p2p/workers", func(c echo.Context) error {
llamaNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID))
mlxNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID))
llamaJSON := make([]map[string]any, 0, len(llamaNodes))
for _, n := range llamaNodes {
llamaJSON = append(llamaJSON, map[string]any{
"name": n.Name,
"id": n.ID,
"tunnelAddress": n.TunnelAddress,
@@ -1055,8 +1056,27 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
})
}
mlxJSON := make([]map[string]any, 0, len(mlxNodes))
for _, n := range mlxNodes {
mlxJSON = append(mlxJSON, map[string]any{
"name": n.Name,
"id": n.ID,
"tunnelAddress": n.TunnelAddress,
"serviceID": n.ServiceID,
"lastSeen": n.LastSeen,
"isOnline": n.IsOnline(),
})
}
return c.JSON(200, map[string]any{
"llama_cpp": map[string]any{
"nodes": llamaJSON,
},
"mlx": map[string]any{
"nodes": mlxJSON,
},
// Keep backward-compatible "nodes" key with llama.cpp workers
"nodes": llamaJSON,
})
})
@@ -1081,13 +1101,14 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
})
app.GET("/api/p2p/stats", func(c echo.Context) error {
llamaCPPNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID))
federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))
mlxWorkerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID))
llamaCPPOnline := 0
for _, n := range llamaCPPNodes {
if n.IsOnline() {
llamaCPPOnline++
}
}
@@ -1098,15 +1119,26 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model
}
}
mlxWorkersOnline := 0
for _, n := range mlxWorkerNodes {
if n.IsOnline() {
mlxWorkersOnline++
}
}
return c.JSON(200, map[string]any{
"llama_cpp_workers": map[string]any{
"online": llamaCPPOnline,
"total": len(llamaCPPNodes),
},
"federated": map[string]interface{}{
"federated": map[string]any{
"online": federatedOnline,
"total": len(federatedNodes),
},
"mlx_workers": map[string]any{
"online": mlxWorkersOnline,
"total": len(mlxWorkerNodes),
},
})
})

View File

@@ -262,8 +262,8 @@
</div>
<div class="text-right">
<div class="text-2xl font-bold">
<span :class="stats.workers.online > 0 ? 'text-[var(--color-success)]' : 'text-[var(--color-error)]'" x-text="stats.workers.online"></span>
<span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.workers.total"></span></span>
<span :class="stats.llama_cpp_workers.online > 0 ? 'text-[var(--color-success)]' : 'text-[var(--color-error)]'" x-text="stats.llama_cpp_workers.online"></span>
<span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.llama_cpp_workers.total"></span></span>
</div>
<p class="text-[var(--color-accent)] text-sm">workers</p>
</div>
@@ -469,8 +469,8 @@ docker run -ti --net host -e TOKEN="<span class="token">{{.P2PToken}}</span>" --
<div class="text-right">
<div class="text-sm text-[var(--color-text-secondary)] mb-1">Active Workers</div>
<div class="text-3xl font-bold">
<span :class="stats.workers.online > 0 ? 'text-[var(--color-accent)]' : 'text-[var(--color-error)]'" x-text="stats.workers.online"></span>
<span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.workers.total"></span></span>
<span :class="stats.llama_cpp_workers.online > 0 ? 'text-[var(--color-accent)]' : 'text-[var(--color-error)]'" x-text="stats.llama_cpp_workers.online"></span>
<span class="text-[var(--color-text-secondary)] text-xl">/<span x-text="stats.llama_cpp_workers.total"></span></span>
</div>
</div>
</div>
@@ -657,7 +657,7 @@ function p2pNetwork() {
workerNodes: [],
federationNodes: [],
stats: {
llama_cpp_workers: { online: 0, total: 0 },
federated: { online: 0, total: 0 }
},

View File

@@ -9,8 +9,9 @@ import (
)
const (
defaultServicesID = "services"
WorkerID = "worker"
defaultServicesID = "services"
LlamaCPPWorkerID = "worker"
MLXWorkerID = "mlx_worker"
)
var mu sync.Mutex

View File

@@ -136,8 +136,9 @@ func (d NodeData) IsOnline() bool {
}
type P2PNodesResponse struct {
LlamaCPPNodes []NodeData `json:"llama_cpp_nodes" yaml:"llama_cpp_nodes"`
FederatedNodes []NodeData `json:"federated_nodes" yaml:"federated_nodes"`
MLXNodes []NodeData `json:"mlx_nodes" yaml:"mlx_nodes"`
}
type SysInfoModel struct {

View File

@@ -0,0 +1,212 @@
+++
disableToc = false
title = "(experimental) MLX Distributed Inference"
weight = 18
url = '/features/mlx-distributed/'
+++
MLX distributed inference allows you to split large language models across multiple Apple Silicon Macs (or other devices) for joint inference. Unlike federation (which distributes whole requests), MLX distributed splits a single model's layers across machines so they all participate in every forward pass.
## How It Works
MLX distributed uses **pipeline parallelism** via the Ring backend: each node holds a slice of the model's layers. During inference, activations flow from rank 0 through each subsequent rank in a pipeline. The last rank gathers the final output.
For high-bandwidth setups (e.g., Thunderbolt-connected Macs), **JACCL** (tensor parallelism via RDMA) is also supported, where each rank holds all layers but with sharded weights.
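The layer-partitioning idea above can be illustrated in plain Python. This is a toy model of pipeline parallelism, not the backend's actual code; `partition_layers` and `forward_pipeline` are hypothetical names, and the inter-rank network hop is reduced to a loop:

```python
def partition_layers(n_layers: int, n_ranks: int) -> list[range]:
    """Split layer indices into one contiguous slice per rank."""
    base, rem = divmod(n_layers, n_ranks)
    slices, start = [], 0
    for r in range(n_ranks):
        size = base + (1 if r < rem else 0)
        slices.append(range(start, start + size))
        start += size
    return slices

def forward_pipeline(x: float, layers: list, n_ranks: int) -> float:
    """Activations flow from rank 0 through each subsequent rank."""
    out = x
    for rank, layer_slice in enumerate(partition_layers(len(layers), n_ranks)):
        # In the real backend this hop is a network transfer between machines.
        for i in layer_slice:
            out = layers[i](out)
    return out

# A 32-layer model split across 2 ranks -> layers 0-15 and 16-31
layers = [lambda v, i=i: v + 1 for i in range(32)]
print(partition_layers(32, 2))           # [range(0, 16), range(16, 32)]
print(forward_pipeline(0.0, layers, 2))  # 32.0
```

The key property is that each rank only needs its own slice of weights in memory, at the cost of one activation transfer per rank boundary.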
## Prerequisites
- Two or more machines with MLX installed (Apple Silicon recommended)
- Network connectivity between all nodes (TCP for Ring, RDMA/Thunderbolt for JACCL)
- Same model accessible on all nodes (e.g., from Hugging Face cache)
## Quick Start with P2P
The simplest way to use MLX distributed is with LocalAI's P2P auto-discovery.
### 1. Start LocalAI with P2P
```bash
docker run -ti --net host \
--name local-ai \
localai/localai:latest-metal-darwin-arm64 run --p2p
```
This generates a network token. Copy it for the next step.
### 2. Start MLX Workers
On each additional Mac:
```bash
docker run -ti --net host \
-e TOKEN="<your-token>" \
--name local-ai-mlx-worker \
localai/localai:latest-metal-darwin-arm64 worker p2p-mlx
```
Workers auto-register on the P2P network. The LocalAI server discovers them and generates a hostfile for MLX distributed.
### 3. Use the Model
Load any MLX-compatible model. The `mlx-distributed` backend will automatically shard it across all available ranks:
```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
model: mlx-community/Llama-3.2-1B-Instruct-4bit
```
## Model Configuration
The `mlx-distributed` backend is started automatically by LocalAI like any other backend. You configure distributed inference through the model YAML file using the `options` field:
### Ring Backend (TCP)
```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
model: mlx-community/Llama-3.2-1B-Instruct-4bit
options:
- "hostfile:/path/to/hosts.json"
- "distributed_backend:ring"
```
The **hostfile** is a JSON array where entry `i` is the `"ip:port"` that **rank `i` listens on** for ring communication. All ranks must use the same hostfile so they know how to reach each other.
**Example:** Two Macs — Mac A (`192.168.1.10`) and Mac B (`192.168.1.11`):
```json
["192.168.1.10:5555", "192.168.1.11:5555"]
```
- Entry 0 (`192.168.1.10:5555`) — the address rank 0 (Mac A) listens on for ring communication
- Entry 1 (`192.168.1.11:5555`) — the address rank 1 (Mac B) listens on for ring communication
Port 5555 is arbitrary — use any available port, but it must be open in your firewall.
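Since a malformed hostfile only surfaces as a connection timeout at load time, it can be worth sanity-checking it up front. A minimal sketch (`load_ring_hostfile` is a hypothetical helper, not part of the LocalAI CLI):

```python
import json

def load_ring_hostfile(path: str) -> list[str]:
    """Parse and sanity-check a Ring hostfile: a JSON array where
    entry i is the "ip:port" that rank i listens on."""
    with open(path) as f:
        hosts = json.load(f)
    if not isinstance(hosts, list) or len(hosts) < 2:
        raise ValueError("hostfile must be a JSON array with at least 2 entries")
    for i, entry in enumerate(hosts):
        # rpartition tolerates IPv6-style hosts containing ':'
        host, sep, port = str(entry).rpartition(":")
        if not sep or not host or not port.isdigit():
            raise ValueError(f"rank {i}: expected 'ip:port', got {entry!r}")
    return hosts
```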
### JACCL Backend (RDMA/Thunderbolt)
```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
model: mlx-community/Llama-3.2-1B-Instruct-4bit
options:
- "hostfile:/path/to/devices.json"
- "distributed_backend:jaccl"
```
The **device matrix** is a JSON 2D array describing the RDMA device name between each pair of ranks. The diagonal is `null` (a rank doesn't talk to itself):
```json
[
[null, "rdma_thunderbolt0"],
["rdma_thunderbolt0", null]
]
```
JACCL requires a **coordinator** — a TCP service that helps all ranks establish RDMA connections. Rank 0 (the LocalAI machine) is always the coordinator. Workers are told the coordinator address via their `--coordinator` CLI flag (see [Starting Workers](#jaccl-workers) below).
### Without hostfile (single-node)
If no `hostfile` option is set and no `MLX_DISTRIBUTED_HOSTFILE` environment variable exists, the backend runs as a regular single-node MLX backend. This is useful for testing or when you don't need distributed inference.
### Available Options
| Option | Description |
|--------|-------------|
| `hostfile` | Path to the hostfile JSON. Ring: array of `"ip:port"`. JACCL: device matrix. |
| `distributed_backend` | `ring` (default) or `jaccl` |
| `trust_remote_code` | Allow trust_remote_code for the tokenizer |
| `max_tokens` | Override default max generation tokens |
| `temperature` / `temp` | Sampling temperature |
| `top_p` | Top-p sampling |
These can also be set via environment variables (`MLX_DISTRIBUTED_HOSTFILE`, `MLX_DISTRIBUTED_BACKEND`) which are used as fallbacks when the model options don't specify them.
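That precedence (model options first, environment variable as fallback, then a default) can be sketched like this; `resolve_option` is a hypothetical name, not the backend's actual implementation:

```python
import os

def resolve_option(options: list[str], key: str, env_var: str, default=None):
    """Resolve a backend option: "key:value" entries from the model
    config win, then the environment variable, then the default."""
    for opt in options:
        # split on the first ':' only, so values like paths keep theirs
        k, sep, v = opt.partition(":")
        if sep and k == key:
            return v
    return os.environ.get(env_var, default)

opts = ["hostfile:/path/to/hosts.json", "distributed_backend:ring"]
print(resolve_option(opts, "hostfile", "MLX_DISTRIBUTED_HOSTFILE"))
# /path/to/hosts.json
print(resolve_option([], "distributed_backend", "UNSET_ENV_VAR", "ring"))
# ring
```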
## Starting Workers
LocalAI starts the rank 0 process (gRPC server) automatically when the model is loaded, but you still need to start **worker processes** (ranks 1, 2, ...) on the other machines. These workers participate in every forward pass but don't serve any API; they wait for commands from rank 0.
### Ring Workers
On each worker machine, start a worker with the same hostfile:
```bash
local-ai worker mlx-distributed --hostfile hosts.json --rank 1
```
The `--rank` must match the worker's position in the hostfile. For example, if `hosts.json` is `["192.168.1.10:5555", "192.168.1.11:5555", "192.168.1.12:5555"]`, then:
- Rank 0: started automatically by LocalAI on `192.168.1.10`
- Rank 1: `local-ai worker mlx-distributed --hostfile hosts.json --rank 1` on `192.168.1.11`
- Rank 2: `local-ai worker mlx-distributed --hostfile hosts.json --rank 2` on `192.168.1.12`
### JACCL Workers
```bash
local-ai worker mlx-distributed \
--hostfile devices.json \
--rank 1 \
--backend jaccl \
--coordinator 192.168.1.10:5555
```
The `--coordinator` address is the IP of the machine running LocalAI (rank 0) combined with any available port. Rank 0 binds the coordinator service there; workers connect to it to establish RDMA connections.
### Worker Startup Order
Start workers **before** loading the model in LocalAI. When LocalAI sends the LoadModel request, rank 0 initializes `mx.distributed` which tries to connect to all ranks listed in the hostfile. If workers aren't running yet, it will time out.
## Advanced: Manual Rank 0
For advanced use cases, you can also run rank 0 manually as an external gRPC backend instead of letting LocalAI start it automatically:
```bash
# On Mac A: start rank 0 manually
local-ai worker mlx-distributed --hostfile hosts.json --rank 0 --addr 192.168.1.10:50051
# On Mac B: start rank 1
local-ai worker mlx-distributed --hostfile hosts.json --rank 1
# On any machine: start LocalAI pointing at rank 0
local-ai run --external-grpc-backends "mlx-distributed:192.168.1.10:50051"
```
Then use a model config with `backend: mlx-distributed` (no need for `hostfile` in options since rank 0 already has it from CLI args).
## CLI Reference
### `worker mlx-distributed`
Starts a worker or manual rank 0 process.
| Flag | Env | Default | Description |
|------|-----|---------|-------------|
| `--hostfile` | `MLX_DISTRIBUTED_HOSTFILE` | *(required)* | Path to hostfile JSON. Ring: array of `"ip:port"` where entry `i` is rank `i`'s listen address. JACCL: device matrix of RDMA device names. |
| `--rank` | `MLX_RANK` | *(required)* | Rank of this process (0 = gRPC server + ring participant, >0 = worker only) |
| `--backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | `ring` (TCP pipeline parallelism) or `jaccl` (RDMA tensor parallelism) |
| `--addr` | `MLX_DISTRIBUTED_ADDR` | `localhost:50051` | gRPC API listen address (rank 0 only, for LocalAI or external access) |
| `--coordinator` | `MLX_JACCL_COORDINATOR` | | JACCL coordinator `ip:port` — rank 0's address for RDMA setup (all ranks must use the same value) |
### `worker p2p-mlx`
P2P mode — auto-discovers peers and generates hostfile.
| Flag | Env | Default | Description |
|------|-----|---------|-------------|
| `--token` | `TOKEN` | *(required)* | P2P network token |
| `--mlx-listen-port` | `MLX_LISTEN_PORT` | `5555` | Port for MLX communication |
| `--mlx-backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | Backend type: `ring` or `jaccl` |
## Troubleshooting
- **All ranks download the model independently.** Each node auto-downloads from Hugging Face on first use via `mlx_lm.load()`. On rank 0 (started by LocalAI), models are downloaded to LocalAI's model directory (`HF_HOME` is set automatically). On workers, models go to the default HF cache (`~/.cache/huggingface/hub`) unless you set `HF_HOME` yourself.
- **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. Start workers before loading the model.
- **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs.
- **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory.
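The rank-assignment rule above (rank 0 is the server, worker ranks ordered by sorted node ID) can be sketched as follows. This is illustrative only: `assign_ranks` is a hypothetical helper, and the port is the `--mlx-listen-port` default:

```python
def assign_ranks(server_addr: str, workers: dict[str, str], port: int = 5555) -> list[str]:
    """Build a Ring hostfile ordering: rank 0 is always the LocalAI
    server; worker ranks follow in sorted node-ID order.
    `workers` maps node ID -> worker IP."""
    hosts = [server_addr]
    for node_id in sorted(workers):
        hosts.append(f"{workers[node_id]}:{port}")
    return hosts

workers = {"node-b": "192.168.1.12", "node-a": "192.168.1.11"}
print(assign_ranks("192.168.1.10:5555", workers))
# ['192.168.1.10:5555', '192.168.1.11:5555', '192.168.1.12:5555']
```

Because the ordering is deterministic, every node that sees the same set of peers derives the same hostfile.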
## Acknowledgements
The MLX distributed auto-parallel sharding implementation is based on [exo](https://github.com/exo-explore/exo).