Compare commits

...

11 Commits

Author                 SHA1        Message                            Date
Ryuichi Leo Takashige  2220eeb0c2  No more upstream stuff ig          2026-01-20 14:49:09 +00:00
Ryuichi Leo Takashige  4dcdca7b71  Fix?                               2026-01-20 14:44:40 +00:00
Ryuichi Leo Takashige  11db77c2c3  Use upstream mlx                   2026-01-20 14:14:05 +00:00
Ryuichi Leo Takashige  5831c86178  Fix if statement...                2026-01-20 14:11:16 +00:00
Ryuichi Leo Takashige  d152cc6585  Revert "Use upstream mlx"          2026-01-20 14:08:53 +00:00
                                   This reverts commit 1a1d676cdf.
Ryuichi Leo Takashige  6078617212  I did a dumb dumb                  2026-01-20 14:07:56 +00:00
Ryuichi Leo Takashige  1a1d676cdf  Use upstream mlx                   2026-01-20 13:08:01 +00:00
Ryuichi Leo Takashige  1835e6f9fd  GPT OSS is still broken upstream   2026-01-20 12:59:48 +00:00
Ryuichi Leo Takashige  00ffc45b0e  Try this                           2026-01-20 12:38:36 +00:00
Ryuichi Leo Takashige  880f31f53f  Only patch for certain models      2026-01-20 12:07:52 +00:00
Ryuichi Leo Takashige  cc3f4c6160  Send chunks from the last rank     2026-01-20 11:57:26 +00:00
6 changed files with 44 additions and 66 deletions

View File

@@ -17,8 +17,8 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx; sys_platform == 'darwin'",
"mlx[cpu]; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
@@ -59,6 +59,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/ml-explore/mlx.git", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
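
The first hunk drops the ==0.30.3 pins on mlx. A quick way to check at runtime which mlx build the resolver actually installed — a hedged sketch, assuming the usual mlx.core attributes; nothing below comes from this repo:

import mlx.core as mx

print(mx.__version__)       # version string of the installed mlx build
print(mx.default_device())  # typically gpu on Apple silicon, cpu with the [cpu] extra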

View File

@@ -194,6 +194,7 @@ def pipeline_auto_parallel(
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],
@@ -229,10 +230,10 @@ def pipeline_auto_parallel(
"Expected a list of layers after auto-parallel initialisation"
)
return patch_pipeline_model(model, group)
return patch_distributed_model(model)
def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
def patch_distributed_model[T](model: T) -> T:
# Patch __call__ on the model's class
cls = model.__class__
original_call = cls.__call__ # type :ignore
@@ -252,36 +253,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call
return model
def patch_tensor_model[T](model: T) -> T:
"""Patch model's __call__ to ensure distributed ops sync during inference."""
cls = model.__class__
original_call = cls.__call__
call_signature = signature(original_call)
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # pyright: ignore[reportAny]
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
cls.__call__ = patched_call
@@ -334,14 +305,23 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
if isinstance(model, (LlamaModel, Ministral3Model)):
# elif hasattr(model, "shard"):
# try:
# model.shard(group) # type: ignore
# return model
# except (AttributeError, TypeError, NameError):
# pass
elif isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
@@ -375,22 +355,13 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
return patch_distributed_model(model)
class TensorParallelShardingStrategy(ABC):
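
Net effect of this file: patch_pipeline_model and patch_tensor_model collapse into a single patch_distributed_model that no longer all_gathers logits across ranks and only chains the last cache entry to the logits via mx.depends, while GptOssModel is routed to the custom sharding strategy ahead of the (now commented-out) upstream model.shard path. Below is a minimal sketch of the surviving patch applied to a toy module — ToyCache, ToyModel, and patch_toy_model are made up for illustration; only the mx.depends call mirrors the code above:

from inspect import signature

import mlx.core as mx
import mlx.nn as nn


class ToyCache:
    """Stand-in for a KV-cache entry exposing a .state array (hypothetical)."""

    def __init__(self) -> None:
        self.state = mx.zeros((1,))


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def __call__(self, x: mx.array, cache: list[ToyCache] | None = None) -> mx.array:
        return self.proj(x)


def patch_toy_model(model: ToyModel) -> ToyModel:
    # Patch __call__ on the model's class, same shape as patch_distributed_model.
    cls = model.__class__
    original_call = cls.__call__
    call_signature = signature(original_call)

    def patched_call(self, *args, **kwargs):
        logits = original_call(self, *args, **kwargs)
        cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get("cache")
        if cache:
            # Tie the last cache entry to the logits so evaluating either one
            # forces any pending (distributed) ops behind the other to run too.
            cache[-1].state = mx.depends(cache[-1].state, logits)
        return logits

    cls.__call__ = patched_call
    return model


model = patch_toy_model(ToyModel())
cache = [ToyCache()]
out = model(mx.ones((1, 4)), cache=cache)
mx.eval(out, cache[-1].state)  # both are realised together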

View File

@@ -169,10 +169,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
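
For reference, the environment this init path now prepares before handing off to the jaccl backend — a condensed sketch using the variable names from the hunk above, with placeholder values:

import os

import mlx.core as mx

os.environ["MLX_IBV_DEVICES"] = "/tmp/jaccl_coordination.json"  # placeholder path
os.environ["MLX_RANK"] = "0"
os.environ["MLX_JACCL_COORDINATOR"] = "192.168.0.10:5000"       # placeholder address
group = mx.distributed.init(backend="jaccl", strict=True)
print(group.rank(), group.size())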

View File

@@ -71,6 +71,7 @@ def main(
bound_instance.bound_shard,
)
device_rank = shard_metadata.device_rank
world_size = shard_metadata.world_size
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
@@ -207,7 +208,7 @@ def main(
for response in mlx_generator:
match response:
case GenerationResponse():
if device_rank == 0:
if device_rank == world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,

View File

@@ -18,7 +18,7 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
from exo.worker.engines.mlx.utils_mlx import shard_and_load
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load
class MockLayer(nn.Module):
@@ -117,11 +117,11 @@ def run_gpt_oss_pipeline_device(
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
model=model, tokenizer=tokenizer, task=task, prompt=prompt
):
generated_text += response.text
if response.finish_reason is not None:
@@ -183,11 +183,11 @@ def run_gpt_oss_tensor_parallel_device(
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
generated_text = ""
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
model=model, tokenizer=tokenizer, task=task, prompt=prompt
):
generated_text += response.text
if response.finish_reason is not None:

View File

@@ -12,7 +12,7 @@ from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
PipelineFirstLayer,
PipelineLastLayer,
patch_pipeline_model,
patch_distributed_model,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
@@ -55,7 +55,7 @@ def run_pipeline_device(
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
inner_model = MockModel([composed])
model = patch_pipeline_model(inner_model, group)
model = patch_distributed_model(inner_model)
x = mx.ones((1, 4))
result = model(x)
@@ -138,9 +138,14 @@ def test_composed_call_works() -> None:
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
)
result_array = results[rank]
# Both devices see the final result (4.0) after all_gather
assert (result_array == 4.0).all(), (
f"Device {rank}: expected 4.0, got {result_array}"
# Each device sees its local result: intermediate ranks return their
# computed output (before sending), last rank returns the final result.
# With world_size=2 and each layer doing x*2:
# - Rank 0: 1.0 * 2 = 2.0 (sends to rank 1)
# - Rank 1: 2.0 * 2 = 4.0 (last rank, final result)
expected = 2.0 * (2**rank) # 2.0 for rank 0, 4.0 for rank 1
assert (result_array == expected).all(), (
f"Device {rank}: expected {expected}, got {result_array}"
)
finally:
os.unlink(hostfile_path)
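
For completeness, the expected value in the assertion generalises: if every rank doubles the activation it receives and the input starts at 1.0, rank r's local result is 2**(r + 1). A quick standalone check of that formula (not part of the test suite):

def expected_local_result(rank: int, x0: float = 1.0) -> float:
    # Rank r receives x0 * 2**r from earlier stages and applies its own x * 2.
    return x0 * 2.0 ** (rank + 1)

assert [expected_local_result(r) for r in range(2)] == [2.0, 4.0]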