mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-01 13:42:20 -04:00
feat(mlx-distributed): add new MLX-distributed backend (#8801)
* feat(mlx-distributed): add new MLX-distributed backend Add new MLX distributed backend with support for both TCP and RDMA for model sharding. This implementation ties in the discovery implementation already in place, and re-uses the same P2P mechanism for the TCP MLX-distributed inferencing. The Auto-parallel implementation is inspired by Exo's ones (who have been added to acknowledgement for the great work!) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * expose a CLI to facilitate backend starting Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: make manual rank0 configurable via model configs Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add missing features from mlx backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
734b6d391f
commit
a026277ab9
23
backend/python/mlx-distributed/Makefile
Normal file
23
backend/python/mlx-distributed/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
# Install the backend's Python environment and dependencies.
.PHONY: mlx-distributed
mlx-distributed:
	bash install.sh

# Launch the backend (gRPC server / worker, depending on run.sh args).
.PHONY: run
run:
	@echo "Running mlx-distributed..."
	bash run.sh
	@echo "mlx-distributed run."

# Run the backend test suite.
.PHONY: test
test:
	@echo "Testing mlx-distributed..."
	bash test.sh
	@echo "mlx-distributed tested."

# Remove generated gRPC protobuf stubs.
.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

# Full cleanup: stubs plus the virtualenv and bytecode caches.
.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
|
||||
509
backend/python/mlx-distributed/backend.py
Normal file
509
backend/python/mlx-distributed/backend.py
Normal file
@@ -0,0 +1,509 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MLX Distributed Inference Backend for LocalAI.
|
||||
|
||||
Two startup modes:
|
||||
|
||||
1. Server mode (started by LocalAI automatically):
|
||||
run.sh --addr localhost:50051
|
||||
Distributed config comes from LoadModel options or env vars.
|
||||
|
||||
2. Worker mode (started by CLI for remote ranks):
|
||||
run.sh --worker --hostfile hosts.json --rank 1 --backend ring
|
||||
Enters a loop waiting for commands from rank 0.
|
||||
"""
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
import grpc
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
|
||||
def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
    """Initialize MLX distributed runtime.

    Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank
    binds to its own entry (hostfile[rank]) and connects to neighbors for the
    ring pipeline.

    JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names.
    MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that
    helps all ranks establish RDMA connections.

    Raises ValueError for any backend other than "ring" or "jaccl".
    """
    import mlx.core as mx

    # Reject unknown backends up front before touching the environment.
    if backend not in ("ring", "jaccl"):
        raise ValueError(f"Unknown backend: {backend}")

    # Both backends identify this process by MLX_RANK.
    os.environ["MLX_RANK"] = str(rank)

    if backend == "ring":
        os.environ["MLX_HOSTFILE"] = hostfile
        os.environ["MLX_RING_VERBOSE"] = "1"
        return mx.distributed.init(backend="ring", strict=True)

    # jaccl: hostfile here carries the RDMA device matrix path.
    os.environ["MLX_IBV_DEVICES"] = hostfile
    if coordinator:
        os.environ["MLX_JACCL_COORDINATOR"] = coordinator
    return mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
|
||||
def is_float(s):
    """Return True when *s* parses as a float, False otherwise."""
    try:
        float(s)
    except ValueError:
        return False
    return True
|
||||
|
||||
|
||||
def is_int(s):
    """Return True when *s* parses as an int, False otherwise."""
    try:
        int(s)
    except ValueError:
        return False
    return True
|
||||
|
||||
|
||||
def _coerce_option_value(raw):
    """Coerce a raw option string: int first, then float, then bool, else str."""
    # Int must be tried BEFORE float: "40" parses as both, and coercing it
    # to 40.0 breaks consumers that require integers (e.g. top_k samplers).
    try:
        return int(raw)
    except ValueError:
        pass
    try:
        return float(raw)
    except ValueError:
        pass
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    return raw


def parse_options(options):
    """Parse "key:value" option strings into a dict with typed values.

    Only the first ":" splits key from value, so values may themselves
    contain colons. Entries without a ":" are silently skipped.

    Args:
        options: iterable of "key:value" strings.

    Returns:
        dict mapping option names to int/float/bool/str values.
    """
    result = {}
    for opt in options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)
        result[key] = _coerce_option_value(value)
    return result
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer for distributed MLX inference (runs on rank 0).

    When started by LocalAI (server mode), distributed init happens at
    LoadModel time using config from model options or environment variables.
    """

    def __init__(self):
        # MLX distributed group; stays None when running single-node.
        self.group = None
        # "ring" or "jaccl" once distributed init has happened.
        self.dist_backend = None
        # Model/tokenizer populated by LoadModel.
        self.model = None
        self.tokenizer = None
        # DistributedCoordinator used to drive worker ranks; None single-node.
        self.coordinator = None
        # Parsed key:value options from the LoadModel request.
        self.options = {}
        # Prompt KV cache — only used in single-node mode (see LoadModel).
        self.lru_cache = None
        # Namespacing key for the prompt cache (the model name).
        self.model_key = None
        # Optional cap on the KV cache size, from the "max_kv_size" option.
        self.max_kv_size = None

    def Health(self, request, context):
        """Liveness probe: always reports OK once the server is up."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """Load (and optionally shard) the model named in the request.

        Distributed mode is enabled when a hostfile is supplied via model
        options or MLX_DISTRIBUTED_HOSTFILE; otherwise runs single-node with
        an LRU prompt cache. Returns a backend_pb2.Result either way.
        """
        try:
            import mlx.core as mx
            from mlx_lm import load
            from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache

            print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr)

            self.options = parse_options(request.Options)
            print(f"Options: {self.options}", file=sys.stderr)

            # Get distributed config from model options, falling back to env vars.
            # If neither is set, run as single-node (no distributed).
            hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", ""))
            dist_backend = str(self.options.get("distributed_backend",
                                                os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring")))
            # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator).
            # Not in model options — rank 0 is the coordinator, workers get
            # the address via their own --coordinator CLI flag.
            jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "")

            if hostfile:
                from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
                from sharding import pipeline_auto_parallel

                print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr)
                self.dist_backend = dist_backend
                # This process is always rank 0 in server mode.
                self.group = mlx_distributed_init(
                    rank=0,
                    hostfile=hostfile,
                    backend=dist_backend,
                    coordinator=jaccl_coordinator or None,
                )
                self.coordinator = DistributedCoordinator(self.group)
                # Tell the worker ranks to load the same model; they mirror
                # these collectives in run_worker().
                self.coordinator.broadcast_command(CMD_LOAD_MODEL)
                self.coordinator.broadcast_model_name(request.Model)
            else:
                print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr)

            # Build tokenizer config from request and options
            tokenizer_config = {}
            if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                tokenizer_config["trust_remote_code"] = True
            # Token overrides from options
            for key in ["eos_token", "pad_token", "bos_token", "unk_token",
                        "sep_token", "cls_token", "mask_token"]:
                if key in self.options:
                    tokenizer_config[key] = self.options[key]

            if tokenizer_config:
                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
                self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
            else:
                self.model, self.tokenizer = load(request.Model)

            if self.group is not None:
                from sharding import pipeline_auto_parallel
                # Shard the freshly loaded model across the group.
                self.model = pipeline_auto_parallel(self.model, self.group)
                print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr)
            else:
                # Single-node: set up prompt cache for efficient generation
                from mlx_cache import ThreadSafeLRUPromptCache
                max_cache_entries = self.options.get("max_cache_entries", 10)
                self.max_kv_size = self.options.get("max_kv_size", None)
                self.model_key = request.Model
                self.lru_cache = ThreadSafeLRUPromptCache(
                    max_size=max_cache_entries,
                    can_trim_fn=can_trim_prompt_cache,
                    trim_fn=trim_prompt_cache,
                )
                print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)

        except Exception as err:
            print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading model: {err}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    async def Predict(self, request, context):
        """Run one non-streaming generation and return the full text.

        In distributed mode, mirrors the collective calls the workers make in
        run_worker() (command, tokens, generation params) before generating.
        """
        prompt_cache = None
        cache_key = None

        try:
            import mlx.core as mx
            from mlx_lm import stream_generate
            from mlx_lm.sample_utils import make_sampler

            prompt_text = self._prepare_prompt(request)
            tokens = self._get_tokens_from_prompt(prompt_text)

            if self.coordinator:
                from coordinator import CMD_GENERATE
                # Broadcast order must match the worker side exactly:
                # command (with token count), then the token ids themselves.
                self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                self.coordinator.broadcast_tokens(tokens)

            max_tokens, sampler_params = self._build_generation_params(request)

            if self.coordinator:
                gen_params = self.coordinator.broadcast_generation_params(
                    max_tokens=max_tokens,
                    temperature=sampler_params.get('temp', 0.6),
                    top_p=sampler_params.get('top_p', 1.0),
                )
                # Use the round-tripped value so all ranks generate the same
                # number of tokens.
                max_tokens = gen_params["max_tokens"]

            sampler = make_sampler(**sampler_params)

            # Use prompt cache in single-node mode
            gen_kwargs = {}
            if self.lru_cache is not None:
                from mlx_lm.models.cache import make_prompt_cache
                cache_key = list(tokens)
                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                    self.model_key, cache_key
                )
                if prompt_cache is None:
                    # No reusable prefix: start a fresh KV cache and feed the
                    # entire prompt.
                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                    remaining_tokens = cache_key
                gen_kwargs['prompt_cache'] = prompt_cache
                tokens = remaining_tokens if remaining_tokens else cache_key

            generated = []
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=tokens,
                max_tokens=max_tokens,
                sampler=sampler,
                **gen_kwargs,
            ):
                generated.append(response.text)
                if cache_key is not None:
                    # Grow the key so the cache entry covers prompt + output.
                    cache_key.append(response.token)

            if self.lru_cache is not None and cache_key is not None:
                self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)

            return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))

        except Exception as e:
            print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    async def PredictStream(self, request, context):
        """Stream generated text chunks back to the caller.

        Same flow as Predict, but yields one Reply per generated chunk and
        re-inserts the prompt cache in a finally block so partial streams
        still populate the cache.
        """
        prompt_cache = None
        cache_key = None

        try:
            import mlx.core as mx
            from mlx_lm import stream_generate
            from mlx_lm.sample_utils import make_sampler

            prompt_text = self._prepare_prompt(request)
            tokens = self._get_tokens_from_prompt(prompt_text)

            if self.coordinator:
                from coordinator import CMD_GENERATE
                # Same collective ordering as Predict — must mirror run_worker().
                self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                self.coordinator.broadcast_tokens(tokens)

            max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)

            if self.coordinator:
                gen_params = self.coordinator.broadcast_generation_params(
                    max_tokens=max_tokens,
                    temperature=sampler_params.get('temp', 0.6),
                    top_p=sampler_params.get('top_p', 1.0),
                )
                max_tokens = gen_params["max_tokens"]

            sampler = make_sampler(**sampler_params)

            # Use prompt cache in single-node mode
            gen_kwargs = {}
            if self.lru_cache is not None:
                from mlx_lm.models.cache import make_prompt_cache
                cache_key = list(tokens)
                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                    self.model_key, cache_key
                )
                if prompt_cache is None:
                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                    remaining_tokens = cache_key
                gen_kwargs['prompt_cache'] = prompt_cache
                tokens = remaining_tokens if remaining_tokens else cache_key

            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=tokens,
                max_tokens=max_tokens,
                sampler=sampler,
                **gen_kwargs,
            ):
                if cache_key is not None:
                    cache_key.append(response.token)
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))

        except Exception as e:
            print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))

        finally:
            # Always return the cache entry, even if the client disconnected
            # mid-stream or generation raised after the cache was fetched.
            if self.lru_cache is not None and prompt_cache is not None and cache_key is not None:
                try:
                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
                except Exception as e:
                    print(f"Error inserting cache: {e}", file=sys.stderr)

    def Embedding(self, request, context):
        """Embeddings are not implemented; always answers UNIMPLEMENTED."""
        print("Embeddings not supported in MLX distributed backend", file=sys.stderr)
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embeddings are not supported in the MLX distributed backend.")
        return backend_pb2.EmbeddingResult()

    def _prepare_prompt(self, request):
        """Return the prompt text, applying the chat template when asked.

        The template path is used only when no raw Prompt was given but
        UseTokenizerTemplate is set and Messages are present.
        """
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return request.Prompt

    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
        """Encode the prompt and normalize the result to a plain list of ints."""
        tokens = self.tokenizer.encode(prompt_text)
        # Some tokenizers return an array-like with .tolist() instead of a list.
        if hasattr(tokens, 'tolist'):
            return tokens.tolist()
        return list(tokens)

    def _build_generation_params(self, request, default_max_tokens=200):
        """Derive (max_tokens, sampler_params) from the request and options.

        Precedence: request fields first (0 means "unset" and falls back to a
        default), then model options override both. Seeds are applied to MLX's
        global RNG as a side effect.
        """
        import mlx.core as mx

        # A Tokens value of 0 means the caller did not set it.
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Temperature 0.0 is treated as "unset" and replaced with 0.6.
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6

        # TopP 0.0 is treated as "unset" and replaced with 1.0 (disabled).
        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0

        sampler_params = {
            'temp': temp,
            'top_p': top_p,
            'min_p': getattr(request, 'MinP', 0.0),
            'top_k': getattr(request, 'TopK', 0),
            'xtc_threshold': 0.0,
            'xtc_probability': 0.0,
        }

        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Model options take precedence over per-request values.
        if hasattr(self, 'options'):
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']
            option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',
                'top_p': 'top_p',
                'min_p': 'min_p',
                'top_k': 'top_k',
                'xtc_threshold': 'xtc_threshold',
                'xtc_probability': 'xtc_probability',
            }
            for opt_key, param_key in option_mapping.items():
                if opt_key in self.options:
                    sampler_params[param_key] = self.options[opt_key]
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        # XTC special tokens
        xtc_special_tokens = []
        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            xtc_special_tokens = [self.tokenizer.eos_token_id]
        try:
            # Newlines are also treated as XTC boundaries; best-effort only.
            newline_tokens = self.tokenizer.encode("\n")
            xtc_special_tokens.extend(newline_tokens)
        except:
            pass
        sampler_params['xtc_special_tokens'] = xtc_special_tokens

        return max_tokens, sampler_params
|
||||
|
||||
|
||||
def run_worker(group):
    """Worker loop for ranks > 0. Waits for commands from rank 0.

    Every collective call here must execute in exactly the same order as the
    rank-0 side (LoadModel/Predict/PredictStream in BackendServicer), since
    MLX collectives are blocking and lockstep across the group.
    """
    from mlx_lm import load, stream_generate
    from mlx_lm.sample_utils import make_sampler
    from coordinator import DistributedCoordinator, CMD_LOAD_MODEL, CMD_GENERATE, CMD_SHUTDOWN
    from sharding import pipeline_auto_parallel
    import mlx.core as mx

    coordinator = DistributedCoordinator(group)
    model = None
    tokenizer = None

    print(f"[Rank {group.rank()}] Worker started, waiting for commands...", file=sys.stderr)

    while True:
        # Blocks in a collective until rank 0 broadcasts the next command.
        cmd, payload_size = coordinator.wait_for_command()

        if cmd == CMD_LOAD_MODEL:
            # Model name arrives via a follow-up broadcast; every rank loads
            # the full weights, then keeps only its pipeline shard.
            model_name = coordinator.broadcast_model_name()
            print(f"[Rank {group.rank()}] Loading model: {model_name}", file=sys.stderr)
            model, tokenizer = load(model_name)
            model = pipeline_auto_parallel(model, group)
            print(f"[Rank {group.rank()}] Model loaded and sharded", file=sys.stderr)

        elif cmd == CMD_GENERATE:
            if model is None:
                # NOTE(review): skipping here desynchronizes this rank from the
                # collectives rank 0 issues next — assumes rank 0 never sends
                # CMD_GENERATE before a successful CMD_LOAD_MODEL.
                print(f"[Rank {group.rank()}] No model loaded, skipping generate", file=sys.stderr)
                continue

            # payload_size carries the token count from the command broadcast;
            # round-trip it, then receive the actual prompt tokens (this rank
            # contributes zeros to the all_sum).
            token_count = coordinator.broadcast_token_count(payload_size)
            tokens_array = coordinator.broadcast_tokens([0] * token_count)
            tokens = tokens_array.tolist()

            gen_params = coordinator.broadcast_generation_params()

            sampler = make_sampler(
                temp=gen_params["temperature"],
                top_p=gen_params["top_p"],
            )

            # Drive the pipeline; the generated text is only consumed on
            # rank 0, so the output here is discarded.
            for _ in stream_generate(
                model, tokenizer,
                prompt=tokens,
                max_tokens=gen_params["max_tokens"],
                sampler=sampler,
            ):
                pass

        elif cmd == CMD_SHUTDOWN:
            print(f"[Rank {group.rank()}] Shutting down", file=sys.stderr)
            break
|
||||
|
||||
|
||||
async def serve(address):
    """Start the asyncio gRPC server on *address* and block until it stops.

    Installs SIGINT/SIGTERM handlers that trigger a graceful stop with a
    5-second grace period for in-flight RPCs.
    """
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            # 50 MiB message limits to accommodate large prompts/outputs.
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    # We are inside a running coroutine, so use get_running_loop():
    # asyncio.get_event_loop() is deprecated here since 3.10 and raises on
    # 3.12+ when no loop is current in the thread.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5)))

    await server.start()
    print(f"[Rank 0] gRPC server listening on {address}", file=sys.stderr)
    await server.wait_for_termination()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MLX Distributed Backend")
    parser.add_argument("--addr", default="localhost:50051",
                        help="gRPC listen address (used by LocalAI to send requests)")
    parser.add_argument("--worker", action="store_true",
                        help="Run in worker mode (for remote ranks started by CLI)")
    parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"],
                        help="ring = TCP pipeline parallelism, jaccl = RDMA tensor parallelism")
    parser.add_argument("--hostfile", default=None,
                        help="Path to hostfile JSON (required for --worker mode)")
    parser.add_argument("--rank", type=int, default=0,
                        help="Rank of this process (0 = server, >0 = worker)")
    parser.add_argument("--coordinator", default=None,
                        help="JACCL coordinator ip:port (jaccl backend only)")
    args = parser.parse_args()

    if args.worker:
        # Worker mode: join the distributed group immediately and loop on
        # commands from rank 0; no gRPC server is started.
        if not args.hostfile:
            print("Error: --hostfile is required in worker mode", file=sys.stderr)
            sys.exit(1)
        group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator)
        run_worker(group)
    else:
        # Server mode: started by LocalAI with just --addr.
        # Distributed init deferred to LoadModel (reads config from model options/env vars).
        asyncio.run(serve(args.addr))
|
||||
104
backend/python/mlx-distributed/coordinator.py
Normal file
104
backend/python/mlx-distributed/coordinator.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Distributed coordination using MLX distributed primitives.
|
||||
|
||||
Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather.
|
||||
Worker ranks wait in a loop for commands from rank 0.
|
||||
"""
|
||||
import json
|
||||
import struct
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# Command opcodes broadcast from rank 0 to the worker ranks via
# DistributedCoordinator.broadcast_command.
CMD_IDLE = 0        # placeholder contributed by non-zero ranks while waiting
CMD_GENERATE = 1    # run a generation pass (payload = prompt token count)
CMD_LOAD_MODEL = 2  # load/shard the model named by a follow-up broadcast
CMD_SHUTDOWN = -1   # leave the worker loop
|
||||
|
||||
|
||||
class DistributedCoordinator:
    """Lockstep command/data exchange between rank 0 and worker ranks.

    All methods are collectives: every rank in the group must call the same
    method at the same point, in the same order, or the group deadlocks.
    Rank 0 contributes real values and other ranks contribute zeros, so the
    all_sum result equals rank 0's data on every rank.
    """

    def __init__(self, group):
        # MLX distributed group this coordinator operates on.
        self.group = group
        self.rank = group.rank()
        self.world_size = group.size()

    def broadcast_command(self, cmd, payload_size=0):
        """Rank 0 broadcasts a command to all ranks.

        Uses all_sum with only rank 0 providing non-zero values so every
        rank receives the same command array.

        Returns (cmd, payload_size) as plain ints on every rank.
        """
        if self.rank == 0:
            cmd_array = mx.array([cmd, payload_size], dtype=mx.int32)
        else:
            cmd_array = mx.zeros((2,), dtype=mx.int32)
        result = mx.distributed.all_sum(cmd_array, group=self.group)
        mx.eval(result)
        return int(result[0].item()), int(result[1].item())

    def broadcast_tokens(self, tokens):
        """Broadcast input token ids from rank 0 to all ranks.

        Rank 0 provides the real token array; other ranks provide zeros of the
        same shape. ``all_sum`` ensures every rank ends up with identical data.

        Note: callers on non-zero ranks must pass a zero/placeholder list of
        the correct length (see broadcast_token_count).
        """
        if self.rank == 0:
            token_array = mx.array(tokens, dtype=mx.int32)
        else:
            token_array = mx.zeros((len(tokens),), dtype=mx.int32)
        result = mx.distributed.all_sum(token_array, group=self.group)
        mx.eval(result)
        return result

    def broadcast_token_count(self, count):
        """Broadcast the number of tokens so workers can prepare a buffer."""
        if self.rank == 0:
            count_array = mx.array([count], dtype=mx.int32)
        else:
            count_array = mx.zeros((1,), dtype=mx.int32)
        result = mx.distributed.all_sum(count_array, group=self.group)
        mx.eval(result)
        return int(result[0].item())

    def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0):
        """Broadcast generation parameters from rank 0.

        Packs [max_tokens, temperature, top_p] into a single float32 array;
        max_tokens is round-tripped through float32 (exact for values < 2^24).
        """
        if self.rank == 0:
            params = mx.array([max_tokens, temperature, top_p], dtype=mx.float32)
        else:
            params = mx.zeros((3,), dtype=mx.float32)
        result = mx.distributed.all_sum(params, group=self.group)
        mx.eval(result)
        return {
            "max_tokens": int(result[0].item()),
            "temperature": float(result[1].item()),
            "top_p": float(result[2].item()),
        }

    def wait_for_command(self):
        """Worker ranks block here until rank 0 broadcasts a command."""
        # CMD_IDLE (0) contributes nothing to the all_sum, so the returned
        # command is whatever rank 0 sent.
        return self.broadcast_command(CMD_IDLE, 0)

    def broadcast_model_name(self, model_name=""):
        """Broadcast model name string from rank 0 to all ranks.

        Encodes the model name as int32 codepoints so it can travel via
        all_sum. Rank 0 passes the real name; workers call with the default
        empty string and receive the decoded name.
        """
        if self.rank == 0:
            encoded = [ord(c) for c in model_name]
            # First broadcast the length
            length = self.broadcast_token_count(len(encoded))
            if length > 0:
                name_array = mx.array(encoded, dtype=mx.int32)
                result = mx.distributed.all_sum(name_array, group=self.group)
                mx.eval(result)
                return model_name
            return ""
        else:
            length = self.broadcast_token_count(0)
            if length > 0:
                name_array = mx.zeros((length,), dtype=mx.int32)
                result = mx.distributed.all_sum(name_array, group=self.group)
                mx.eval(result)
                # Decode the summed codepoints back into a string.
                return "".join(chr(int(c.item())) for c in result)
            return ""
|
||||
15
backend/python/mlx-distributed/install.sh
Normal file
15
backend/python/mlx-distributed/install.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
set -e

# Consumed by libbackend.sh (presumably select pip vs conda and pin the
# interpreter version — confirm against common/libbackend.sh).
USE_PIP=true
PYTHON_VERSION=""

# Quote all expansions: the backend may be installed under a path containing
# spaces, and unquoted $0/$backend_dir would word-split and break.
backend_dir=$(dirname "$0")

# Prefer a bundled common/ directory; fall back to the sibling one used when
# this backend lives inside the LocalAI source tree.
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

installRequirements
|
||||
266
backend/python/mlx-distributed/mlx_cache.py
Normal file
266
backend/python/mlx-distributed/mlx_cache.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Thread-safe LRU prompt cache for MLX-based backends.
|
||||
|
||||
Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
|
||||
with thread-safety additions for LocalAI's gRPC backend.
|
||||
|
||||
Usage:
|
||||
from mlx_cache import ThreadSafeLRUPromptCache
|
||||
|
||||
# In LoadModel:
|
||||
self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)
|
||||
|
||||
# In Predict/PredictStream:
|
||||
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
|
||||
# ... generate ...
|
||||
self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
|
||||
"""
|
||||
import copy
|
||||
import threading
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """A cache entry with reference counting."""
    # The KV cache object (list of per-layer caches from mlx_lm).
    prompt_cache: List[Any]
    # Number of live references; when > 1, extraction deep-copies instead of
    # removing the entry from the trie.
    count: int
|
||||
|
||||
|
||||
@dataclass
class SearchResult:
    """Result of searching the cache trie."""
    # Model key the search was performed under.
    model: Any
    # Exact-match token tuple, or None.
    exact: Optional[List[int]]
    # Token tuple of a cached strict prefix of the query, or None.
    shorter: Optional[List[int]]
    # Token tuple of a cached sequence extending past the query (trimmable), or None.
    longer: Optional[List[int]]
    # Length of the common prefix between the query and the trie walk.
    common_prefix: int
|
||||
|
||||
|
||||
class ThreadSafeLRUPromptCache:
|
||||
"""
|
||||
Thread-safe LRU cache with prefix matching for prompt KV caches.
|
||||
|
||||
This cache stores KV caches keyed by token sequences and supports:
|
||||
- Exact match: Return the cache for the exact token sequence
|
||||
- Shorter prefix match: Return a cache for a prefix of the tokens
|
||||
- Longer prefix match: If a longer sequence is cached and can be trimmed
|
||||
- LRU eviction: When max_size is exceeded, evict least recently used
|
||||
|
||||
Thread safety is provided via a threading.Lock that protects all
|
||||
cache operations.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of cache entries (default: 10)
|
||||
can_trim_fn: Optional function to check if a cache can be trimmed
|
||||
trim_fn: Optional function to trim a cache
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        max_size: int = 10,
        can_trim_fn: Optional[Any] = None,
        trim_fn: Optional[Any] = None,
    ):
        # Maximum number of cache entries before LRU eviction kicks in.
        self.max_size = max_size
        # Per-model trie: model -> nested {token: {...}} dicts with "cache" leaves.
        self._cache = {}
        # LRU order of (model, tokens) keys; least recently used at the left.
        self._lru = deque()
        # Single lock guarding all cache operations.
        self._lock = threading.Lock()

        # Optional trim functions (for longer prefix reuse)
        self._can_trim_fn = can_trim_fn
        self._trim_fn = trim_fn
|
||||
|
||||
    def _search(self, model, tokens: List[int]) -> SearchResult:
        """
        Search the cache for a prompt cache. Return exact or close match.

        The cache is organized as a trie where each node is keyed by a token.
        This allows efficient prefix matching. Caller must hold self._lock.
        """
        if model not in self._cache:
            return SearchResult(model, None, None, None, 0)

        current = self._cache[model]
        # Index of the deepest node with a "cache" leaf seen so far.
        last_cache_index = -1
        index = 0

        # Traverse the trie following the token sequence
        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "cache" in current:
                last_cache_index = index
            index += 1

        # Exact match - no need to search for longer or shorter caches
        if last_cache_index == len(tokens) - 1:
            return SearchResult(model, tuple(tokens), None, None, 0)

        # Find the shorter cache (a prefix that has a cache)
        # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
        # Single-token prefixes are not matched, which allows longer cached
        # sequences to be preferred for trimming. This is acceptable because
        # real prompts with chat templates are always many tokens.
        shorter = None
        if last_cache_index > 0:
            shorter = tuple(tokens[: last_cache_index + 1])

        # Check for caches that are longer than our token sequence
        longer = None
        common_prefix = index
        if index > 0 and last_cache_index <= 0:
            # DFS below the deepest matched node for the nearest "cache" leaf
            # (fewest extra tokens beyond the common prefix).
            best = None
            stack = [(current, [])]
            while stack:
                current, extra = stack.pop()
                if "cache" in current:
                    if best is None or len(extra) < len(best):
                        best = extra
                else:
                    for tok in current:
                        stack.append((current[tok], extra + [tok]))
            if best is not None:
                longer = tuple(tokens[:index] + best)

        return SearchResult(model, None, shorter, longer, common_prefix)
|
||||
|
||||
    def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """Get a cache entry by traversing the trie.

        Assumes the entry exists (raises KeyError otherwise); caller must
        hold self._lock.
        """
        current = self._cache[model]
        for tok in tokens:
            current = current[tok]
        return current["cache"]
|
||||
|
||||
    def _delete(self, model, tokens: Tuple[int, ...]) -> None:
        """Delete a cache entry and clean up empty trie nodes.

        Caller must hold self._lock.
        """
        # Record the node path so we can prune empty ancestors afterwards.
        path = [self._cache[model]]
        for tok in tokens:
            path.append(path[-1][tok])
        del path[-1]["cache"]

        # Clean up empty nodes bottom-up
        for i in reversed(range(len(tokens))):
            d_prev, d, t = path[i], path[i + 1], tokens[i]
            if len(d) > 0:
                # Node still has children (or another cache below) — stop pruning.
                break
            del d_prev[t]
|
||||
|
||||
    def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """
        Extract a cache entry for exclusive use.

        If the entry has count > 1, deep copy and decrement.
        If count == 1, remove from cache entirely.

        Caller must hold self._lock.
        """
        cache_entry = self._get(model, tokens)
        if cache_entry.count == 1:
            # Sole reference: hand over the entry itself and drop it from
            # both the trie and the LRU list.
            self._delete(model, tokens)
            self._lru.remove((model, tokens))
            return cache_entry

        # Shared entry: leave the original in place and give the caller a
        # private deep copy with a fresh refcount of 1.
        cache_entry.count -= 1
        return CacheEntry(
            copy.deepcopy(cache_entry.prompt_cache),
            1,
        )
|
||||
|
||||
def fetch_nearest_cache(
    self, model, tokens: List[int]
) -> Tuple[Optional[List[Any]], List[int]]:
    """
    Fetch the nearest cache for the given token sequence.

    Thread-safe. Returns (cache, remaining_tokens) where:
    - cache: The KV cache to use (or None if no cache found)
    - remaining_tokens: Tokens that still need to be processed

    Args:
        model: Model identifier (used to namespace caches)
        tokens: The full token sequence for the prompt

    Returns:
        Tuple of (prompt_cache, remaining_tokens)
    """
    with self._lock:
        result = self._search(model, tokens)

        # Exact match - extract (exclusive ownership) and return; nothing
        # left to process.
        if result.exact is not None:
            cache_entry = self._extract(result.model, result.exact)
            return cache_entry.prompt_cache, []

        # Shorter prefix match - extract and return the uncached suffix.
        if result.shorter is not None:
            cache_entry = self._extract(result.model, result.shorter)
            prefix_len = len(result.shorter)
            return cache_entry.prompt_cache, list(tokens[prefix_len:])

        # Longer prefix match - usable only if we can actually trim the
        # cache back to the common prefix. Both hooks are required: with
        # only _can_trim_fn set we would previously return an UNtrimmed
        # copy alongside remaining tokens computed as if trimming had
        # happened, desynchronizing cache and token stream.
        if (
            result.longer is not None
            and self._can_trim_fn is not None
            and self._trim_fn is not None
        ):
            cache_entry = self._get(result.model, result.longer)
            if self._can_trim_fn(cache_entry.prompt_cache):
                # Deep copy so the shared entry stays intact, then trim
                # the copy down to the common prefix.
                trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
                # Keep at least one token to process so generation has input.
                prefix = min(len(tokens) - 1, result.common_prefix)
                self._trim_fn(trimmed_cache, len(result.longer) - prefix)
                return trimmed_cache, list(tokens[prefix:])

        # No usable match found: caller must process the full prompt.
        return None, list(tokens)
|
||||
|
||||
def insert_cache(
    self, model, tokens: List[int], prompt_cache: List[Any]
) -> None:
    """
    Insert a cache entry after generation completes.

    Thread-safe. If an entry for the exact token sequence already exists,
    its reference count is bumped (the supplied ``prompt_cache`` is not
    stored) and the entry is refreshed in the LRU order. Evicts the
    least-recently-used entry when ``max_size`` is exceeded.

    Args:
        model: Model identifier (used to namespace caches)
        tokens: The full token sequence (prompt + generated)
        prompt_cache: The KV cache to store
    """
    with self._lock:
        key = tuple(tokens)

        # Descend the per-model trie, creating nodes as needed.
        node = self._cache.setdefault(model, {})
        for token in key:
            node = node.setdefault(token, {})

        existing = node.get("cache")
        if existing is not None:
            # Same sequence already cached: just add a reference and
            # move it to the most-recently-used position.
            existing.count += 1
            self._lru.remove((model, key))
        else:
            node["cache"] = CacheEntry(prompt_cache, 1)

        self._lru.append((model, key))

        # Evict least-recently-used entries while over capacity.
        while len(self._lru) > self.max_size:
            old_model, old_tokens = self._lru.popleft()
            self._delete(old_model, old_tokens)
|
||||
|
||||
def clear(self) -> None:
    """Drop every cached entry and reset LRU bookkeeping. Thread-safe."""
    with self._lock:
        self._lru.clear()
        self._cache.clear()
|
||||
|
||||
def __len__(self) -> int:
    """Return how many cache entries are currently stored. Thread-safe."""
    with self._lock:
        # One LRU record exists per stored entry, so its length is the count.
        size = len(self._lru)
    return size
|
||||
2
backend/python/mlx-distributed/requirements-cpu.txt
Normal file
2
backend/python/mlx-distributed/requirements-cpu.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
mlx-lm
|
||||
mlx[cpu]
|
||||
2
backend/python/mlx-distributed/requirements-cublas12.txt
Normal file
2
backend/python/mlx-distributed/requirements-cublas12.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
mlx-lm
|
||||
mlx[cuda12]
|
||||
2
backend/python/mlx-distributed/requirements-cublas13.txt
Normal file
2
backend/python/mlx-distributed/requirements-cublas13.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
mlx-lm
|
||||
mlx[cuda13]
|
||||
2
backend/python/mlx-distributed/requirements-l4t12.txt
Normal file
2
backend/python/mlx-distributed/requirements-l4t12.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
mlx-lm
|
||||
mlx[cuda12]
|
||||
2
backend/python/mlx-distributed/requirements-l4t13.txt
Normal file
2
backend/python/mlx-distributed/requirements-l4t13.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
mlx-lm
|
||||
mlx[cuda13]
|
||||
1
backend/python/mlx-distributed/requirements-mps.txt
Normal file
1
backend/python/mlx-distributed/requirements-mps.txt
Normal file
@@ -0,0 +1 @@
|
||||
mlx-lm
|
||||
4
backend/python/mlx-distributed/requirements.txt
Normal file
4
backend/python/mlx-distributed/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
11
backend/python/mlx-distributed/run.sh
Normal file
11
backend/python/mlx-distributed/run.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash

# Resolve this backend's directory so the script works from any CWD.
# Quoted to survive paths containing spaces.
backend_dir=$(dirname "$0")

# Source the shared backend helpers: prefer a bundled ./common copy,
# otherwise fall back to the repository-level ../common.
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Forward all CLI arguments to the backend launcher verbatim.
startBackend "$@"
|
||||
136
backend/python/mlx-distributed/sharding.py
Normal file
136
backend/python/mlx-distributed/sharding.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Auto-parallelism for MLX distributed inference.
|
||||
|
||||
Provides pipeline parallelism (Ring backend) by wrapping model layers with
|
||||
distributed send/recv operations. Ported from exo's auto_parallel.py with
|
||||
simplifications for LocalAI's use case.
|
||||
"""
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PipelineFirstLayer(nn.Module):
    """Wraps the first layer on each rank to receive from the previous rank."""

    def __init__(self, original_layer, rank, group):
        super().__init__()
        # NOTE(review): mlx.nn.Module appears to be dict-based; storing the
        # wrapped layer via dict.__setitem__ (rather than attribute
        # assignment) presumably keeps it out of normal module/parameter
        # registration — confirm against mlx.nn internals.
        dict.__setitem__(self, "_original_layer", original_layer)
        self.rank = rank          # this process's rank in the pipeline group
        self.group = group        # mx.distributed group used for recv

    @property
    def original_layer(self):
        # Retrieve the wrapped layer from dict storage (see __init__).
        return self["_original_layer"]

    def __getattr__(self, name):
        # Transparently proxy unknown attributes to the wrapped layer so
        # callers can treat this wrapper as the original layer.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self["_original_layer"], name)

    def __call__(self, x, *args, **kwargs):
        # Every rank except rank 0 first receives its input activations
        # from the previous pipeline stage.
        if self.rank != 0:
            # Force evaluation before/after the recv so the lazy graph is
            # materialized at the communication boundary.
            mx.eval(x)
            x = mx.distributed.recv_like(x, self.rank - 1, group=self.group)
            mx.eval(x)
        return self.original_layer(x, *args, **kwargs)
|
||||
|
||||
|
||||
class PipelineLastLayer(nn.Module):
    """Wraps the last layer on each rank to send to the next rank."""

    def __init__(self, original_layer, rank, world_size, group):
        super().__init__()
        # NOTE(review): dict.__setitem__ storage mirrors PipelineFirstLayer;
        # presumably keeps the wrapped layer out of module registration —
        # confirm against mlx.nn internals.
        dict.__setitem__(self, "_original_layer", original_layer)
        self.rank = rank                # this process's pipeline rank
        self.world_size = world_size    # total number of pipeline stages
        self.group = group              # mx.distributed group for send/gather

    @property
    def original_layer(self):
        # Retrieve the wrapped layer from dict storage (see __init__).
        return self["_original_layer"]

    def __getattr__(self, name):
        # Proxy unknown attributes to the wrapped layer.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self["_original_layer"], name)

    def __call__(self, x, *args, **kwargs):
        output = self.original_layer(x, *args, **kwargs)
        mx.eval(output)
        # Every rank except the last forwards its activations to the next
        # pipeline stage. (mx.distributed.send returns the array, so the
        # reassignment keeps the send in the evaluated graph.)
        if self.rank != self.world_size - 1:
            output = mx.distributed.send(
                output, (self.rank + 1) % self.world_size, group=self.group
            )
            mx.eval(output)
        # Gather output from all ranks so every rank has the final result
        # NOTE(review): all_gather is a collective, so it must run on every
        # rank; the trailing slice keeps only the last rank's segment
        # (the final pipeline output) — confirm shapes for batched inputs.
        output = mx.distributed.all_gather(output, group=self.group)[
            -output.shape[0] :
        ]
        mx.eval(output)
        return output
|
||||
|
||||
|
||||
def get_inner_model(model):
    """Return the transformer core of ``model``.

    Looks for ``model.model`` or ``model.transformer``; if the found module
    itself exposes a nested ``model`` attribute (e.g. language_model.model),
    that innermost module is returned instead.

    Raises:
        ValueError: if neither attribute holds an nn.Module.
    """
    for attr_name in ("model", "transformer"):
        candidate = getattr(model, attr_name, None)
        if not isinstance(candidate, nn.Module):
            continue
        # Unwrap one more level when present (e.g. language_model.model).
        nested = getattr(candidate, "model", None)
        return nested if isinstance(nested, nn.Module) else candidate
    raise ValueError("Model must have a 'model' or 'transformer' attribute")
|
||||
|
||||
|
||||
def get_layers(inner_model):
    """Return the transformer layer list (``layers`` or GPT-style ``h``).

    Raises:
        ValueError: if neither attribute exists on ``inner_model``.
    """
    for attr_name in ("layers", "h"):
        found = getattr(inner_model, attr_name, None)
        if found is not None:
            return found
    raise ValueError("Model must have a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None):
    """Apply pipeline parallelism to a model.

    Each rank only keeps its slice of layers. The first layer receives from
    the previous rank, and the last layer sends to the next rank.

    Args:
        model: The MLX model (must have model.layers or similar)
        group: The distributed group
        start_layer: First layer index for this rank (auto-computed if None)
        end_layer: Last layer index (exclusive) for this rank (auto-computed if None)
    """
    rank = group.rank()
    world_size = group.size()

    inner = get_inner_model(model)
    all_layers = list(get_layers(inner))
    total = len(all_layers)

    # Auto-compute an even split; the first `remainder` ranks get one extra
    # layer each.
    if start_layer is None or end_layer is None:
        per_rank, remainder = divmod(total, world_size)
        start_layer = rank * per_rank + min(rank, remainder)
        end_layer = start_layer + per_rank + (1 if rank < remainder else 0)

    shard = all_layers[start_layer:end_layer]
    # Materialize this rank's layer weights.
    for layer in shard:
        mx.eval(layer)

    # Wrap the boundary layers with the pipeline recv/send logic.
    shard[0] = PipelineFirstLayer(shard[0], rank, group=group)
    shard[-1] = PipelineLastLayer(shard[-1], rank, world_size, group=group)

    # Install the shard back onto the inner model.
    if hasattr(inner, "layers"):
        inner.layers = shard
    elif hasattr(inner, "h"):
        inner.h = shard

    return model
|
||||
87
backend/python/mlx-distributed/test.py
Normal file
87
backend/python/mlx-distributed/test.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
    """
    Integration tests for the mlx-distributed gRPC backend.

    Each test spawns backend.py as a subprocess and exercises it over a
    localhost gRPC channel.

    NOTE(review): unittest already calls setUp/tearDown around every test,
    so the explicit self.setUp()/self.tearDown() calls inside each test
    start and stop the server a second time — kept as-is to match the
    sibling backend test suites; confirm this is intentional.
    """

    def setUp(self):
        # Launch the backend gRPC server as a child process on a fixed port.
        self.service = subprocess.Popen(
            ["python", "backend.py", "--addr", "localhost:50051"]
        )
        # Crude readiness wait: give the server time to bind and initialize.
        time.sleep(10)

    def tearDown(self) -> None:
        # Stop the backend process and reap it to avoid zombies.
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """The backend should answer the Health RPC once started."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                # NOTE(review): compares a protobuf field to bytes b'OK' —
                # verify the Health reply's message field is bytes, not str.
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """LoadModel should succeed for a small public MLX model."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_text(self):
        """Predict should return some completion for a simple prompt."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("text service failed")
        finally:
            self.tearDown()

    def test_sampling_params(self):
        """Predict should accept the full set of sampling parameters."""
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)

                # Exercise every sampling knob the proto exposes at once.
                req = backend_pb2.PredictOptions(
                    Prompt="The capital of France is",
                    TopP=0.8,
                    Tokens=50,
                    Temperature=0.7,
                    TopK=40,
                    MinP=0.05,
                    Seed=42,
                )
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("sampling params service failed")
        finally:
            self.tearDown()
|
||||
12
backend/python/mlx-distributed/test.sh
Normal file
12
backend/python/mlx-distributed/test.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
# Abort on the first failing command so test failures are not masked.
set -e

# Resolve this backend's directory so the script works from any CWD.
# Quoted to survive paths containing spaces.
backend_dir=$(dirname "$0")

# Source the shared backend helpers: prefer a bundled ./common copy,
# otherwise fall back to the repository-level ../common.
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
|
||||
Reference in New Issue
Block a user