Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-21 04:22:21 -05:00)
Compare commits: aiohttp...leo/send-f (11 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 2220eeb0c2 | |
| | 4dcdca7b71 | |
| | 11db77c2c3 | |
| | 5831c86178 | |
| | d152cc6585 | |
| | 6078617212 | |
| | 1a1d676cdf | |
| | 1835e6f9fd | |
| | 00ffc45b0e | |
| | 880f31f53f | |
| | cc3f4c6160 | |
@@ -17,8 +17,8 @@ dependencies = [
     "loguru>=0.7.3",
     "exo_pyo3_bindings", # rust bindings
     "anyio==4.11.0",
-    "mlx==0.30.3; sys_platform == 'darwin'",
-    "mlx[cpu]==0.30.3; sys_platform == 'linux'",
+    "mlx; sys_platform == 'darwin'",
+    "mlx[cpu]; sys_platform == 'linux'",
     "mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",

@@ -59,6 +59,7 @@ members = [
 
 [tool.uv.sources]
 exo_pyo3_bindings = { workspace = true }
+mlx = { git = "https://github.com/ml-explore/mlx.git", branch = "main" }
 # Uncomment to use local mlx/mlx-lm development versions:
 # mlx = { path = "/Users/Shared/mlx", editable=true }
 # mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }

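The two pyproject.toml hunks above drop the `mlx==0.30.3` pins and point `[tool.uv.sources]` at the MLX `main` branch on git. To check which mlx build actually resolved in a given environment, a minimal sketch (not from the repository, standard library only) is:

```python
# Minimal sketch, not from the repo: inspect how the "mlx" distribution was installed.
from importlib import metadata

try:
    dist = metadata.distribution("mlx")
    print("mlx version:", dist.version)
    # PEP 610: installs that came straight from a VCS URL record it in direct_url.json.
    origin = dist.read_text("direct_url.json")
    print("installed from:", origin if origin else "a package index (no direct_url.json)")
except metadata.PackageNotFoundError:
    print("mlx is not installed in this environment")
```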
@@ -194,6 +194,7 @@ def pipeline_auto_parallel(
     device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
 
     layers = layers[start_layer:end_layer]
 
+    layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
     layers[-1] = PipelineLastLayer(
         layers[-1],

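This hunk wraps the first and last layers of each pipeline stage so activations flow between ranks; the diff does not show `PipelineFirstLayer` or `PipelineLastLayer` themselves. The sketch below is an illustrative stand-in, not the repository's implementation, assuming MLX's point-to-point ops `mx.distributed.send` and `mx.distributed.recv_like`:

```python
# Illustrative stand-ins for the first/last pipeline wrappers; names and behaviour
# are assumptions except for mx.depends, which the repo's patch code also uses.
import mlx.core as mx
import mlx.nn as nn


class FirstLayerSketch(nn.Module):
    """Receive activations from the previous rank before running the local layer."""

    def __init__(self, layer: nn.Module, device_rank: int, group=None):
        super().__init__()
        self.layer = layer
        self.device_rank = device_rank
        self.group = group

    def __call__(self, x: mx.array, *args, **kwargs) -> mx.array:
        if self.device_rank > 0:
            # Replace the placeholder input with the previous rank's output.
            x = mx.distributed.recv_like(x, self.device_rank - 1, group=self.group)
        return self.layer(x, *args, **kwargs)


class LastLayerSketch(nn.Module):
    """Send this rank's output to the next rank, if there is one."""

    def __init__(self, layer: nn.Module, device_rank: int, world_size: int, group=None):
        super().__init__()
        self.layer = layer
        self.device_rank = device_rank
        self.world_size = world_size
        self.group = group

    def __call__(self, x: mx.array, *args, **kwargs) -> mx.array:
        out = self.layer(x, *args, **kwargs)
        if self.device_rank < self.world_size - 1:
            sent = mx.distributed.send(out, self.device_rank + 1, group=self.group)
            # Tie the send into the graph so lazy evaluation cannot drop it.
            out = mx.depends(out, sent)
        return out
```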
@@ -229,10 +230,10 @@ def pipeline_auto_parallel(
         "Expected a list of layers after auto-parallel initialisation"
     )
 
-    return patch_pipeline_model(model, group)
+    return patch_distributed_model(model)
 
 
-def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
+def patch_distributed_model[T](model: T) -> T:
     # Patch __call__ on the model's class
     cls = model.__class__
     original_call = cls.__call__  # type :ignore

@@ -252,36 +253,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
         if cache is not None:
             cache[-1].state = mx.depends(cache[-1].state, logits)  # type: ignore
 
-        logits = mx.distributed.all_gather(logits, group=group)[
-            -logits.shape[0] :
-        ]  # type :ignore
-
         return logits
 
     cls.__call__ = patched_call
-    return model
-
-
-def patch_tensor_model[T](model: T) -> T:
-    """Patch model's __call__ to ensure distributed ops sync during inference."""
-    cls = model.__class__
-    original_call = cls.__call__
-    call_signature = signature(original_call)
-
-    def patched_call(
-        self: T,
-        *args: object,
-        **kwargs: object,
-    ) -> mx.array:
-        logits: mx.array = original_call(self, *args, **kwargs)  # pyright: ignore[reportAny]
-        cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
-            "cache", None
-        )
-
-        # Add dependency to last cache entry to ensure distributed ops are evaluated
-        if cache is not None and len(cache) > 0:  # pyright: ignore[reportAny]
-            cache[-1].state = mx.depends(cache[-1].state, logits)  # pyright: ignore[reportAny,reportUnknownMemberType]
-
-        return logits
-
-    cls.__call__ = patched_call

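After this hunk a single `patch_distributed_model` replaces the separate pipeline and tensor variants, and the `all_gather` over logits is gone. Reconstructed only from the context lines that survive in the diff (binding `cache` out of the call arguments and chaining it to the logits with `mx.depends`), a minimal sketch of the patching technique is:

```python
# Minimal sketch of the class-level __call__ patch; reconstructed from the diff's
# context lines, so treat it as an approximation rather than the repo's exact code.
from inspect import signature

import mlx.core as mx


def patch_distributed_model_sketch[T](model: T) -> T:
    cls = model.__class__
    original_call = cls.__call__
    call_signature = signature(original_call)

    def patched_call(self: T, *args: object, **kwargs: object) -> mx.array:
        logits: mx.array = original_call(self, *args, **kwargs)
        cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
            "cache", None
        )
        # Chain the cache state to the logits so any distributed ops inside the
        # forward pass are forced to evaluate together with the rest of the graph.
        if cache is not None and len(cache) > 0:
            cache[-1].state = mx.depends(cache[-1].state, logits)
        return logits

    cls.__call__ = patched_call
    return model
```

Because the patch is applied to the model's class rather than the instance, it affects every instance of that class in the process; that is the same trade-off the original `patch_pipeline_model` made.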
@@ -334,14 +305,23 @@ def tensor_auto_parallel(
         group=group,
     )
 
-    if hasattr(model, "shard"):
-        try:
-            model.shard(group)  # type: ignore
-            return patch_tensor_model(model)
-        except (AttributeError, TypeError, NameError):
-            pass
-    if isinstance(model, (LlamaModel, Ministral3Model)):
+    if isinstance(model, GptOssModel):
+        tensor_parallel_sharding_strategy = GptOssShardingStrategy(
+            group,
+            all_to_sharded_linear,
+            sharded_to_all_linear,
+            all_to_sharded_linear_in_place,
+            sharded_to_all_linear_in_place,
+        )
+
+    # elif hasattr(model, "shard"):
+    #     try:
+    #         model.shard(group)  # type: ignore
+    #         return model
+    #     except (AttributeError, TypeError, NameError):
+    #         pass
+
+    elif isinstance(model, (LlamaModel, Ministral3Model)):
+        logger.warning("shouldn't be hit - upstream sharding exists")
         tensor_parallel_sharding_strategy = LlamaShardingStrategy(
             group,

@@ -375,22 +355,13 @@
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
-    elif isinstance(model, GptOssModel):
-        tensor_parallel_sharding_strategy = GptOssShardingStrategy(
-            group,
-            all_to_sharded_linear,
-            sharded_to_all_linear,
-            all_to_sharded_linear_in_place,
-            sharded_to_all_linear_in_place,
-        )
-
     else:
         raise ValueError(f"Unsupported model type: {type(model)}")
 
     model = tensor_parallel_sharding_strategy.shard_model(
         model, timeout_seconds, on_timeout
     )
-    return patch_tensor_model(model)
+    return patch_distributed_model(model)
 
 
 class TensorParallelShardingStrategy(ABC):

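The dispatch above ends with `tensor_parallel_sharding_strategy.shard_model(model, timeout_seconds, on_timeout)` before the model is patched. The concrete strategy classes are not part of this diff; the skeleton below only mirrors the constructor arguments and the `shard_model` call seen here, and everything else is an assumption:

```python
# Hypothetical skeleton of a tensor-parallel sharding strategy; only the constructor
# arguments and shard_model(model, timeout_seconds, on_timeout) appear in the diff.
from abc import ABC, abstractmethod
from typing import Callable

import mlx.core as mx


class ShardingStrategySketch(ABC):
    def __init__(
        self,
        group: mx.distributed.Group,
        all_to_sharded_linear: Callable,
        sharded_to_all_linear: Callable,
        all_to_sharded_linear_in_place: Callable,
        sharded_to_all_linear_in_place: Callable,
    ) -> None:
        self.group = group
        self.all_to_sharded_linear = all_to_sharded_linear
        self.sharded_to_all_linear = sharded_to_all_linear
        self.all_to_sharded_linear_in_place = all_to_sharded_linear_in_place
        self.sharded_to_all_linear_in_place = sharded_to_all_linear_in_place

    @abstractmethod
    def shard_model(self, model, timeout_seconds, on_timeout):
        """Return the model with its linear layers replaced by sharded equivalents."""
```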
@@ -169,10 +169,10 @@ def mlx_distributed_init(
 
     # TODO: update once upstream fixes
     logger.info(
-        f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
+        f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
     )
     logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
-    os.environ["MLX_JACCL_DEVICES"] = coordination_file
+    os.environ["MLX_IBV_DEVICES"] = coordination_file
     os.environ["MLX_RANK"] = str(rank)
     os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
     group = mx.distributed.init(backend="jaccl", strict=True)

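This hunk renames the device-list variable from `MLX_JACCL_DEVICES` to `MLX_IBV_DEVICES` while keeping `MLX_RANK` and `MLX_JACCL_COORDINATOR`. A hedged sketch of the environment bootstrap around `mx.distributed.init(backend="jaccl", strict=True)`, with placeholder values instead of the repo's coordination file:

```python
# Sketch of the jaccl bootstrap shown in the hunk; the path and coordinator
# address below are placeholders, not values from the repository.
import os

import mlx.core as mx

rank = 0
coordination_file = "/tmp/jaccl_devices.json"  # placeholder path
jaccl_coordinator = "10.0.0.1:5000"            # placeholder host:port

os.environ["MLX_IBV_DEVICES"] = coordination_file  # was MLX_JACCL_DEVICES before this change
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator

# strict=True makes init raise if the backend is unavailable instead of silently
# falling back to a single-process group.
group = mx.distributed.init(backend="jaccl", strict=True)
print("rank", group.rank(), "of", group.size())
```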
@@ -71,6 +71,7 @@ def main(
         bound_instance.bound_shard,
     )
     device_rank = shard_metadata.device_rank
+    world_size = shard_metadata.world_size
     logger.info("hello from the runner")
     if getattr(shard_metadata, "immediate_exception", False):
         raise Exception("Fake exception - runner failed to spin up.")

@@ -207,7 +208,7 @@ def main(
     for response in mlx_generator:
         match response:
             case GenerationResponse():
-                if device_rank == 0:
+                if device_rank == world_size - 1:
                     event_sender.send(
                         ChunkGenerated(
                             command_id=command_id,

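With the logits `all_gather` removed in the earlier hunk, only the last pipeline rank holds the final-layer output, so the runner now emits `ChunkGenerated` from rank `world_size - 1` instead of rank 0. A tiny sketch of that gate, with a hypothetical helper name:

```python
# Sketch only: which rank should stream tokens after this change.
def should_emit_chunks(device_rank: int, world_size: int) -> bool:
    # Before: rank 0 emitted, because all_gather broadcast the logits to every rank.
    # After: only the last pipeline rank computes the final logits, so it emits.
    return device_rank == world_size - 1


assert should_emit_chunks(device_rank=1, world_size=2)
assert not should_emit_chunks(device_rank=0, world_size=2)
```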
@@ -18,7 +18,7 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
 from exo.worker.engines.mlx import Model
 from exo.worker.engines.mlx.generator.generate import mlx_generate
-from exo.worker.engines.mlx.utils_mlx import shard_and_load
+from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load
 
 
 class MockLayer(nn.Module):

@@ -117,11 +117,11 @@ def run_gpt_oss_pipeline_device(
         max_tokens=max_tokens,
     )
 
+    prompt = apply_chat_template(tokenizer, task)
+
     generated_text = ""
     for response in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
+        model=model, tokenizer=tokenizer, task=task, prompt=prompt
     ):
         generated_text += response.text
         if response.finish_reason is not None:

@@ -183,11 +183,11 @@ def run_gpt_oss_tensor_parallel_device(
         max_tokens=max_tokens,
     )
 
+    prompt = apply_chat_template(tokenizer, task)
+
     generated_text = ""
     for response in mlx_generate(
-        model=model,
-        tokenizer=tokenizer,
-        task=task,
+        model=model, tokenizer=tokenizer, task=task, prompt=prompt
     ):
         generated_text += response.text
         if response.finish_reason is not None:

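Both test helpers now build the prompt explicitly with `apply_chat_template(tokenizer, task)` and pass it to `mlx_generate` as `prompt=`. A usage sketch of that call pattern follows; the wrapper function and the early `break` are illustrative, and only the two calls and their keyword arguments come from the diff:

```python
# Illustrative wrapper around the call pattern used in both test helpers.
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import apply_chat_template


def generate_text(model, tokenizer, task: ChatCompletionTaskParams) -> str:
    prompt = apply_chat_template(tokenizer, task)

    generated_text = ""
    for response in mlx_generate(
        model=model, tokenizer=tokenizer, task=task, prompt=prompt
    ):
        generated_text += response.text
        if response.finish_reason is not None:
            break  # assumed: stop once the generator reports a finish reason
    return generated_text
```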
@@ -12,7 +12,7 @@ from exo.worker.engines.mlx.auto_parallel import (
     CustomMlxLayer,
     PipelineFirstLayer,
     PipelineLastLayer,
-    patch_pipeline_model,
+    patch_distributed_model,
 )
 from exo.worker.tests.unittests.test_mlx.conftest import MockLayer

@@ -55,7 +55,7 @@ def run_pipeline_device(
 
     # Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
     inner_model = MockModel([composed])
-    model = patch_pipeline_model(inner_model, group)
+    model = patch_distributed_model(inner_model)
 
     x = mx.ones((1, 4))
     result = model(x)

@@ -138,9 +138,14 @@ def test_composed_call_works() -> None:
                 f"Device {rank} failed: {errors.get(rank, 'unknown')}"
             )
             result_array = results[rank]
-            # Both devices see the final result (4.0) after all_gather
-            assert (result_array == 4.0).all(), (
-                f"Device {rank}: expected 4.0, got {result_array}"
+            # Each device sees its local result: intermediate ranks return their
+            # computed output (before sending), last rank returns the final result.
+            # With world_size=2 and each layer doing x*2:
+            # - Rank 0: 1.0 * 2 = 2.0 (sends to rank 1)
+            # - Rank 1: 2.0 * 2 = 4.0 (last rank, final result)
+            expected = 2.0 * (2**rank)  # 2.0 for rank 0, 4.0 for rank 1
+            assert (result_array == expected).all(), (
+                f"Device {rank}: expected {expected}, got {result_array}"
             )
     finally:
         os.unlink(hostfile_path)

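The updated assertion encodes the new pipeline semantics: each rank returns its own local output, so with two ranks and mock layers that double their input the expected values are 2.0 on rank 0 and 4.0 on rank 1. A standalone check of that arithmetic, independent of the MLX test harness:

```python
# Standalone check of the expected-value formula used in the updated test.
def expected_output(rank: int, start: float = 1.0, factor: float = 2.0) -> float:
    # Each rank applies one doubling layer to the previous rank's output.
    return start * factor ** (rank + 1)


assert expected_output(0) == 2.0  # matches `expected = 2.0 * (2**0)`
assert expected_output(1) == 4.0  # matches `expected = 2.0 * (2**1)`
```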