mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-02 14:16:02 -04:00
feat(mlx-distributed): add new MLX-distributed backend (#8801)
* feat(mlx-distributed): add new MLX-distributed backend Add new MLX distributed backend with support for both TCP and RDMA for model sharding. This implementation ties in the discovery implementation already in place, and re-uses the same P2P mechanism for the TCP MLX-distributed inferencing. The Auto-parallel implementation is inspired by Exo's ones (who have been added to acknowledgement for the great work!) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * expose a CLI to facilitate backend starting Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: make manual rank0 configurable via model configs Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add missing features from mlx backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
734b6d391f
commit
a026277ab9
87
backend/python/mlx-distributed/test.py
Normal file
87
backend/python/mlx-distributed/test.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.service = subprocess.Popen(
|
||||
["python", "backend.py", "--addr", "localhost:50051"]
|
||||
)
|
||||
time.sleep(10)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.service.terminate()
|
||||
self.service.wait()
|
||||
|
||||
def test_server_startup(self):
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b'OK')
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Server failed to start")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_load_model(self):
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "Model loaded successfully")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_text(self):
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
|
||||
self.assertTrue(response.success)
|
||||
req = backend_pb2.PredictOptions(Prompt="The capital of France is")
|
||||
resp = stub.Predict(req)
|
||||
self.assertIsNotNone(resp.message)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("text service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_sampling_params(self):
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
|
||||
self.assertTrue(response.success)
|
||||
|
||||
req = backend_pb2.PredictOptions(
|
||||
Prompt="The capital of France is",
|
||||
TopP=0.8,
|
||||
Tokens=50,
|
||||
Temperature=0.7,
|
||||
TopK=40,
|
||||
MinP=0.05,
|
||||
Seed=42,
|
||||
)
|
||||
resp = stub.Predict(req)
|
||||
self.assertIsNotNone(resp.message)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("sampling params service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
Reference in New Issue
Block a user