+++
disableToc = false
title = "(experimental) MLX Distributed Inference"
weight = 18
url = '/features/mlx-distributed/'
+++
MLX distributed inference allows you to split large language models across multiple Apple Silicon Macs (or other devices) for joint inference. Unlike federation (which distributes whole requests), MLX distributed splits a single model's layers across machines so they all participate in every forward pass.
## How It Works
MLX distributed uses **pipeline parallelism** via the Ring backend: each node holds a slice of the model's layers. During inference, activations flow from rank 0 through each subsequent rank in a pipeline. The last rank gathers the final output.
For high-bandwidth setups (e.g., Thunderbolt-connected Macs), **JACCL** (tensor parallelism via RDMA) is also supported, where each rank holds all layers but with sharded weights.
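To make the layer split concrete, here is a minimal sketch of how a model's layers might be partitioned into contiguous slices, one per rank. This is illustrative only; `partition_layers` is a hypothetical helper, not the backend's actual partitioning code:

```python
def partition_layers(num_layers: int, num_ranks: int) -> list[range]:
    """Assign a contiguous slice of layers to each rank (hypothetical sketch)."""
    base, extra = divmod(num_layers, num_ranks)
    slices, start = [], 0
    for rank in range(num_ranks):
        # Spread any remainder over the first `extra` ranks.
        count = base + (1 if rank < extra else 0)
        slices.append(range(start, start + count))
        start += count
    return slices

# 32 layers over 3 ranks: ranks 0 and 1 hold 11 layers each, rank 2 holds 10.
print(partition_layers(32, 3))
```

During inference, rank 0 runs its slice first, sends the activations to rank 1, and so on, which is why every rank participates in every forward pass.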
## Prerequisites
- Two or more machines with MLX installed (Apple Silicon recommended)
- Network connectivity between all nodes (TCP for Ring, RDMA/Thunderbolt for JACCL)
- Same model accessible on all nodes (e.g., from Hugging Face cache)
## Quick Start with P2P
The simplest way to use MLX distributed is with LocalAI's P2P auto-discovery.
### 1. Start LocalAI with P2P
```bash
docker run -ti --net host \
  --name local-ai \
  localai/localai:latest-metal-darwin-arm64 run --p2p
```
This generates a network token. Copy it for the next step.
### 2. Start MLX Workers
On each additional Mac:
```bash
docker run -ti --net host \
  -e TOKEN="<your-token>" \
  --name local-ai-mlx-worker \
  localai/localai:latest-metal-darwin-arm64 worker p2p-mlx
```
Workers auto-register on the P2P network. The LocalAI server discovers them and generates a hostfile for MLX distributed.
### 3. Use the Model
Load any MLX-compatible model. The `mlx-distributed` backend will automatically shard it across all available ranks:
```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
  model: mlx-community/Llama-3.2-1B-Instruct-4bit
```
## Model Configuration
The `mlx-distributed` backend is started automatically by LocalAI like any other backend. You configure distributed inference through the model YAML file using the `options` field:
### Ring Backend (TCP)
```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
  model: mlx-community/Llama-3.2-1B-Instruct-4bit
options:
  - "hostfile:/path/to/hosts.json"
  - "distributed_backend:ring"
```
The **hostfile** is a JSON array where entry `i` is the `"ip:port"` that **rank `i` listens on** for ring communication. All ranks must use the same hostfile so they know how to reach each other.
**Example:** Two Macs — Mac A (`192.168.1.10`) and Mac B (`192.168.1.11`):
```json
["192.168.1.10:5555", "192.168.1.11:5555"]
```
- Entry 0 (`192.168.1.10:5555`) — the address rank 0 (Mac A) listens on for ring communication
- Entry 1 (`192.168.1.11:5555`) — the address rank 1 (Mac B) listens on for ring communication
Port 5555 is arbitrary — use any available port, but it must be open in your firewall.
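Before starting workers, it can help to sanity-check the hostfile format. A small illustrative validator (a hypothetical helper, not part of LocalAI):

```python
import json

def validate_ring_hostfile(path: str) -> list[str]:
    """Check that a Ring hostfile is a non-empty JSON array of "ip:port" strings.

    Hypothetical helper for catching malformed hostfiles before workers start;
    not part of LocalAI itself.
    """
    with open(path) as f:
        hosts = json.load(f)
    if not isinstance(hosts, list) or not hosts:
        raise ValueError("hostfile must be a non-empty JSON array")
    for i, entry in enumerate(hosts):
        host, sep, port = str(entry).rpartition(":")
        if not sep or not host or not port.isdigit():
            raise ValueError(f"entry {i} ({entry!r}) is not in 'ip:port' form")
    return hosts
```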
### JACCL Backend (RDMA/Thunderbolt)
```yaml
name: llama-distributed
backend: mlx-distributed
parameters:
  model: mlx-community/Llama-3.2-1B-Instruct-4bit
options:
  - "hostfile:/path/to/devices.json"
  - "distributed_backend:jaccl"
```
The **device matrix** is a JSON 2D array describing the RDMA device name between each pair of ranks. The diagonal is `null` (a rank doesn't talk to itself):
```json
[
  [null, "rdma_thunderbolt0"],
  ["rdma_thunderbolt0", null]
]
```
JACCL requires a **coordinator** — a TCP service that helps all ranks establish RDMA connections. Rank 0 (the LocalAI machine) is always the coordinator. Workers are told the coordinator address via their `--coordinator` CLI flag (see [Starting Workers](#jaccl-workers) below).
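A quick structural check of the device matrix can catch mistakes early. Again, this is a hypothetical helper sketched for illustration, not part of LocalAI:

```python
import json

def validate_device_matrix(path: str) -> None:
    """Sanity-check a JACCL device matrix (hypothetical helper).

    Expects an N x N array where the diagonal is null and every
    off-diagonal entry names the RDMA device linking the two ranks.
    """
    with open(path) as f:
        matrix = json.load(f)
    n = len(matrix)
    for i, row in enumerate(matrix):
        if len(row) != n:
            raise ValueError(f"row {i} has {len(row)} entries, expected {n}")
        if row[i] is not None:
            raise ValueError(f"diagonal entry [{i}][{i}] must be null")
        for j, dev in enumerate(row):
            if i != j and not isinstance(dev, str):
                raise ValueError(f"entry [{i}][{j}] must be a device name string")
```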
### Without hostfile (single-node)
If no `hostfile` option is set and no `MLX_DISTRIBUTED_HOSTFILE` environment variable exists, the backend runs as a regular single-node MLX backend. This is useful for testing or when you don't need distributed inference.
### Available Options
| Option | Description |
|--------|-------------|
| `hostfile` | Path to the hostfile JSON. Ring: array of `"ip:port"`. JACCL: device matrix. |
| `distributed_backend` | `ring` (default) or `jaccl` |
| `trust_remote_code` | Allow trust_remote_code for the tokenizer |
| `max_tokens` | Override default max generation tokens |
| `temperature` / `temp` | Sampling temperature |
| `top_p` | Top-p sampling |
These can also be set via environment variables (`MLX_DISTRIBUTED_HOSTFILE`, `MLX_DISTRIBUTED_BACKEND`) which are used as fallbacks when the model options don't specify them.
## Starting Workers
LocalAI starts the rank 0 process (gRPC server) automatically when the model is loaded. But you still need to start **worker processes** (ranks 1, 2, ...) on the other machines. These workers participate in every forward pass but don't serve any API — they wait for commands from rank 0.
### Ring Workers
On each worker machine, start a worker with the same hostfile:
```bash
local-ai worker mlx-distributed --hostfile hosts.json --rank 1
```
The `--rank` must match the worker's position in the hostfile. For example, if `hosts.json` is `["192.168.1.10:5555", "192.168.1.11:5555", "192.168.1.12:5555"]`, then:
- Rank 0: started automatically by LocalAI on `192.168.1.10`
- Rank 1: `local-ai worker mlx-distributed --hostfile hosts.json --rank 1` on `192.168.1.11`
- Rank 2: `local-ai worker mlx-distributed --hostfile hosts.json --rank 2` on `192.168.1.12`
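Since a worker's rank is simply its index in the hostfile, the rank can be derived rather than hard-coded. An illustrative helper (hypothetical, not a LocalAI command):

```python
import json

def rank_for_address(hostfile_path: str, my_addr: str) -> int:
    """Return this machine's rank: the index of its "ip:port" entry
    in the shared hostfile. Hypothetical helper for scripting worker startup."""
    with open(hostfile_path) as f:
        hosts = json.load(f)
    try:
        return hosts.index(my_addr)
    except ValueError:
        raise SystemExit(f"{my_addr} is not listed in {hostfile_path}")
```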
### JACCL Workers
```bash
local-ai worker mlx-distributed \
  --hostfile devices.json \
  --rank 1 \
  --backend jaccl \
  --coordinator 192.168.1.10:5555
```
The `--coordinator` address is the IP of the machine running LocalAI (rank 0), paired with any available port. Rank 0 binds the coordinator service there; workers connect to it to establish RDMA connections.
### Worker Startup Order
Start workers **before** loading the model in LocalAI. When LocalAI sends the LoadModel request, rank 0 initializes `mx.distributed` which tries to connect to all ranks listed in the hostfile. If workers aren't running yet, it will time out.
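A simple pre-flight probe can confirm the workers are listening before you load the model. This is a hypothetical helper that only checks TCP reachability of the hostfile entries for ranks 1 and up (rank 0's entry is skipped, since LocalAI starts that process itself when the model loads):

```python
import json
import socket

def unreachable_workers(hostfile_path: str, timeout: float = 2.0) -> list[str]:
    """Return worker hostfile entries (ranks 1+) not yet accepting TCP connections."""
    with open(hostfile_path) as f:
        hosts = json.load(f)
    down = []
    for entry in hosts[1:]:  # skip rank 0: LocalAI starts it on model load
        host, _, port = entry.rpartition(":")
        try:
            with socket.create_connection((host, int(port)), timeout=timeout):
                pass
        except OSError:
            down.append(entry)
    return down
```

If the returned list is empty, all workers are up and it should be safe to load the model.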
## Advanced: Manual Rank 0
For advanced use cases, you can also run rank 0 manually as an external gRPC backend instead of letting LocalAI start it automatically:
```bash
# On Mac A: start rank 0 manually
local-ai worker mlx-distributed --hostfile hosts.json --rank 0 --addr 192.168.1.10:50051

# On Mac B: start rank 1
local-ai worker mlx-distributed --hostfile hosts.json --rank 1

# On any machine: start LocalAI pointing at rank 0
local-ai run --external-grpc-backends "mlx-distributed:192.168.1.10:50051"
```
Then use a model config with `backend: mlx-distributed` (no need for `hostfile` in options since rank 0 already has it from CLI args).
## CLI Reference
### `worker mlx-distributed`
Starts a worker or manual rank 0 process.
| Flag | Env | Default | Description |
|------|-----|---------|-------------|
| `--hostfile` | `MLX_DISTRIBUTED_HOSTFILE` | *(required)* | Path to hostfile JSON. Ring: array of `"ip:port"` where entry `i` is rank `i`'s listen address. JACCL: device matrix of RDMA device names. |
| `--rank` | `MLX_RANK` | *(required)* | Rank of this process (0 = gRPC server + ring participant, >0 = worker only) |
| `--backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | `ring` (TCP pipeline parallelism) or `jaccl` (RDMA tensor parallelism) |
| `--addr` | `MLX_DISTRIBUTED_ADDR` | `localhost:50051` | gRPC API listen address (rank 0 only, for LocalAI or external access) |
| `--coordinator` | `MLX_JACCL_COORDINATOR` | | JACCL coordinator `ip:port` — rank 0's address for RDMA setup (all ranks must use the same value) |
### `worker p2p-mlx`
P2P mode — auto-discovers peers and generates hostfile.
| Flag | Env | Default | Description |
|------|-----|---------|-------------|
| `--token` | `TOKEN` | *(required)* | P2P network token |
| `--mlx-listen-port` | `MLX_LISTEN_PORT` | `5555` | Port for MLX communication |
| `--mlx-backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | Backend type: `ring` or `jaccl` |
## Troubleshooting
- **All ranks download the model independently.** Each node auto-downloads from Hugging Face on first use via `mlx_lm.load()`. On rank 0 (started by LocalAI), models are downloaded to LocalAI's model directory (`HF_HOME` is set automatically). On workers, models go to the default HF cache (`~/.cache/huggingface/hub`) unless you set `HF_HOME` yourself.
- **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. Start workers before loading the model.
- **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs.
- **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory.
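The P2P rank assignment described above (the LocalAI server is rank 0, worker ranks follow sorted node IDs) can be sketched as follows. This is illustrative only, not LocalAI's actual implementation:

```python
def assign_ranks(server_id: str, worker_ids: list[str]) -> dict[str, int]:
    """Hypothetical sketch: server is rank 0, workers ranked by sorted node ID."""
    ranks = {server_id: 0}
    for rank, node in enumerate(sorted(worker_ids), start=1):
        ranks[node] = rank
    return ranks

print(assign_ranks("server", ["node-b", "node-a"]))
# {'server': 0, 'node-a': 1, 'node-b': 2}
```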
## Acknowledgements
The MLX distributed auto-parallel sharding implementation is based on [exo](https://github.com/exo-explore/exo).