Compare commits

...

8 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
06c7e157b9 try manual synchronization 2026-01-21 11:54:02 +00:00
Ryuichi Leo Takashige
f5886c1953 try placing in pipeline last layer 2026-01-21 11:35:25 +00:00
Ryuichi Leo Takashige
ded7b35a2e Eval everything 2026-01-21 11:21:28 +00:00
Ryuichi Leo Takashige
ff4fecf66a Debug 2026-01-21 11:19:16 +00:00
Ryuichi Leo Takashige
b12804d705 Try no mx_barrier 2026-01-21 11:15:57 +00:00
Ryuichi Leo Takashige
0ad852870a Use recv/send over all gather 2026-01-21 11:12:48 +00:00
Ryuichi Leo Takashige
ab2f3ac731 Revert "Fix small pipeline"
This reverts commit 4fec1ac7d6.
2026-01-20 20:58:23 +00:00
Ryuichi Leo Takashige
4fec1ac7d6 Fix small pipeline 2026-01-20 20:55:52 +00:00
2 changed files with 38 additions and 9 deletions

View File

@@ -113,7 +113,10 @@ class PipelineFirstLayer(CustomMlxLayer):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
logger.info(f"[PipelineFirstLayer] recv_like from rank {self.r - 1}, shape={x.shape}")
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
mx.eval(x)
logger.info(f"[PipelineFirstLayer] recv_like done, sum={x.sum().item():.4f}")
return self.original_layer(x, *args, **kwargs)
@@ -139,11 +142,41 @@ class PipelineLastLayer(CustomMlxLayer):
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
output = mx.distributed.send(
mx.eval(output)
logger.info(f"[PipelineLastLayer] send to rank {(self.r + 1) % self.s}, shape={output.shape}, sum={output.sum().item():.4f}")
sent = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
mx.eval(sent)
logger.info(f"[PipelineLastLayer] send done")
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
cache.keys = mx.depends(cache.keys, sent) # type: ignore[reportUnknownMemberType]
if self.r == self.s - 1:
mx.eval(output)
out_sum = output.sum().item()
logger.info(f"[patch_pipeline_model] send logits to rank 0, shape={output.shape}, sum={out_sum:.4f}")
output = mx.distributed.send(output, dst=0, group=self.group)
mx.eval(output)
logger.info(f"[patch_pipeline_model] send done")
output = mx.distributed.recv_like(output, src=0, group=self.group)
mx.eval(output)
recv_sum = output.sum().item()
logger.info(f"[patch_pipeline_model] recv_like_r done, sum={recv_sum:.4f}")
elif self.r == 0:
logger.info(f"[patch_pipeline_model] recv_like logits from rank {self.s - 1}, shape={output.shape}")
output = mx.distributed.recv_like(output, src=self.s - 1, group=self.group)
mx.eval(output)
recv_sum = output.sum().item()
logger.info(f"[patch_pipeline_model] recv_like done, sum={recv_sum:.4f}")
output = mx.distributed.send(output, dst=self.s - 1, group=self.group)
mx.eval(output)
logger.info(f"[patch_pipeline_model] send_r done")
# Synchronize before next iteration to ensure all JACCL operations complete
mx.synchronize()
logger.info(f"[PipelineLastLayer] iteration complete, synchronized")
return output
@@ -238,6 +271,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
original_call = cls.__call__  # type: ignore
call_signature = signature(original_call)  # type: ignore
world_size = group.size()
rank = group.rank()
def patched_call(
self: T,
*args: object,
@@ -252,10 +288,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type: ignore
return logits
cls.__call__ = patched_call

View File

@@ -286,9 +286,6 @@ def shard_and_load(
logger.debug("SHARDED")
logger.debug(model)
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model, tokenizer