Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-23 21:41:21 -05:00)

Compare commits: ciaran/ima ... leo/fix-sm (8 commits)
| SHA1 |
|---|
| 06c7e157b9 |
| f5886c1953 |
| ded7b35a2e |
| ff4fecf66a |
| b12804d705 |
| 0ad852870a |
| ab2f3ac731 |
| 4fec1ac7d6 |
```diff
@@ -113,7 +113,10 @@ class PipelineFirstLayer(CustomMlxLayer):
     def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
         if self.r != 0:
+            logger.info(f"[PipelineFirstLayer] recv_like from rank {self.r - 1}, shape={x.shape}")
             x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
+            mx.eval(x)
+            logger.info(f"[PipelineFirstLayer] recv_like done, sum={x.sum().item():.4f}")
         return self.original_layer(x, *args, **kwargs)
```
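The pattern in this hunk — a non-zero rank blocking on `recv_like` from the previous stage and forcing the transfer with `mx.eval` before running the wrapped layer — can be exercised on its own. Below is a minimal sketch, assuming a distributed group from `mx.distributed.init()` and a launch across two or more processes; the placeholder array and the `y = x + rank` stand-in for `original_layer` are illustrative, not part of the patch.

```python
import mlx.core as mx

# Hypothetical stand-in for PipelineFirstLayer: every rank except rank 0
# waits for the previous stage's activations before computing.
group = mx.distributed.init()
rank, size = group.rank(), group.size()

x = mx.zeros((1, 8))  # placeholder activation; real shapes come from the model
if rank != 0:
    # Blocks until the previous rank's send arrives; shape/dtype are taken from x.
    x = mx.distributed.recv_like(x, rank - 1, group=group)
    mx.eval(x)  # force the transfer now rather than deep inside the next layer's graph

y = x + rank  # stand-in for original_layer(x)

if rank != size - 1:
    # Hand the result to the next pipeline stage.
    mx.eval(mx.distributed.send(y, rank + 1, group=group))
```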
```diff
@@ -139,11 +142,41 @@ class PipelineLastLayer(CustomMlxLayer):
         output: mx.array = self.original_layer(x, *args, **kwargs)

         if self.r != self.s - 1:
-            output = mx.distributed.send(
+            mx.eval(output)
+            logger.info(f"[PipelineLastLayer] send to rank {(self.r + 1) % self.s}, shape={output.shape}, sum={output.sum().item():.4f}")
+            sent = mx.distributed.send(
                 output, (self.r + 1) % self.s, group=self.group
             )
+            mx.eval(sent)
+            logger.info(f"[PipelineLastLayer] send done")
             if cache is not None:
-                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
+                cache.keys = mx.depends(cache.keys, sent)  # type: ignore[reportUnknownMemberType]

+        if self.r == self.s - 1:
+            mx.eval(output)
+            out_sum = output.sum().item()
+            logger.info(f"[patch_pipeline_model] send logits to rank 0, shape={output.shape}, sum={out_sum:.4f}")
+            output = mx.distributed.send(output, dst=0, group=self.group)
+            mx.eval(output)
+            logger.info(f"[patch_pipeline_model] send done")
+            output = mx.distributed.recv_like(output, src=0, group=self.group)
+            mx.eval(output)
+            recv_sum = output.sum().item()
+            logger.info(f"[patch_pipeline_model] recv_like_r done, sum={recv_sum:.4f}")
+
+        elif self.r == 0:
+            logger.info(f"[patch_pipeline_model] recv_like logits from rank {self.s - 1}, shape={output.shape}")
+            output = mx.distributed.recv_like(output, src=self.s - 1, group=self.group)
+            mx.eval(output)
+            recv_sum = output.sum().item()
+            logger.info(f"[patch_pipeline_model] recv_like done, sum={recv_sum:.4f}")
+            output = mx.distributed.send(output, dst=self.s - 1, group=self.group)
+            mx.eval(output)
+            logger.info(f"[patch_pipeline_model] send_r done")
+
+        # Synchronize before next iteration to ensure all JACCL operations complete
+        mx.synchronize()
+        logger.info(f"[PipelineLastLayer] iteration complete, synchronized")

         return output
```
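The additions here force every distributed op with `mx.eval` and log a checksum, so a hang or corrupted transfer can be pinned to a specific rank and step. The final logit ping-pong (last rank sends to rank 0, rank 0 echoes back) pairs a `send` on one side with a `recv_like` on the other. A standalone sketch of that pairing follows, assuming two or more ranks launched together; the dummy logits and print statements are illustrative only.

```python
import mlx.core as mx

group = mx.distributed.init()
rank, size = group.rank(), group.size()
last = size - 1

logits = mx.arange(4, dtype=mx.float32) * (rank + 1)  # dummy "logits"

if rank == last and size > 1:
    # Last stage pushes its logits to rank 0, then waits for the echo.
    mx.eval(mx.distributed.send(logits, 0, group=group))
    echoed = mx.distributed.recv_like(logits, 0, group=group)
    mx.eval(echoed)
    print(f"rank {rank}: echo sum={echoed.sum().item():.4f}")
elif rank == 0 and size > 1:
    # Rank 0 receives the logits and sends them straight back.
    received = mx.distributed.recv_like(logits, last, group=group)
    mx.eval(received)
    mx.eval(mx.distributed.send(received, last, group=group))

# Drain any outstanding work before the next iteration, mirroring the patch.
mx.synchronize()
```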
```diff
@@ -238,6 +271,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
     original_call = cls.__call__  # type: ignore
     call_signature = signature(original_call)  # type: ignore

+    world_size = group.size()
+    rank = group.rank()
+
     def patched_call(
         self: T,
         *args: object,
```
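These lines simply capture `group.size()` and `group.rank()` in the enclosing scope so the patched `__call__` can branch on them. The surrounding wrapping follows a standard monkey-patch-with-closure shape; the sketch below mirrors it with a hypothetical `Toy` class and `patch` helper that are not part of the repo.

```python
from inspect import signature

import mlx.core as mx


class Toy:
    """Hypothetical layer class standing in for the model's class."""

    def __call__(self, x: mx.array) -> mx.array:
        return x * 2


def patch(cls: type, group: mx.distributed.Group) -> None:
    original_call = cls.__call__
    call_signature = signature(original_call)  # kept to inspect kwargs such as `cache`
    world_size = group.size()
    rank = group.rank()

    def patched_call(self, *args, **kwargs):
        out = original_call(self, *args, **kwargs)
        # rank/world_size are closed over here, as in the patch above.
        return out + rank if world_size > 1 else out

    cls.__call__ = patched_call


group = mx.distributed.init()
patch(Toy, group)
print(Toy()(mx.array([1.0, 2.0])))
```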
```diff
@@ -252,10 +288,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
         if cache is not None:
             cache[-1].state = mx.depends(cache[-1].state, logits)  # type: ignore

-        logits = mx.distributed.all_gather(logits, group=group)[
-            -logits.shape[0] :
-        ]  # type: ignore
-
         return logits

     cls.__call__ = patched_call
```
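The deleted lines gathered logits from every rank and kept only the final shard; with the explicit send/recv exchange added in `PipelineLastLayer`, that collective is no longer needed. For reference, the removed collective behaves like the sketch below (shapes are arbitrary).

```python
import mlx.core as mx

group = mx.distributed.init()

# Each rank holds a (2, 4) block of "logits".
logits = mx.ones((2, 4)) * group.rank()

# all_gather concatenates every rank's block along axis 0,
# giving shape (2 * world_size, 4) on every rank.
gathered = mx.distributed.all_gather(logits, group=group)

# The removed code then sliced the last `logits.shape[0]` rows,
# i.e. the block contributed by the highest rank.
last_block = gathered[-logits.shape[0]:]
mx.eval(last_block)
```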
```diff
@@ -286,9 +286,6 @@ def shard_and_load(
     logger.debug("SHARDED")
     logger.debug(model)

-    # Synchronize processes before generation to avoid timeout
-    mx_barrier(group)
-
     return model, tokenizer
```
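`mx_barrier` appears to be a repo-local helper; this hunk drops the pre-generation barrier along with its comment about avoiding timeouts. If an explicit barrier is still wanted, a common way to build one in MLX is to evaluate a small collective, as in the sketch below (this is not the repo's `mx_barrier`).

```python
import mlx.core as mx


def barrier(group: mx.distributed.Group) -> None:
    # Every rank must contribute to all_sum before any rank can read the result,
    # so forcing its evaluation acts as a synchronization point.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), group=group))


group = mx.distributed.init()
barrier(group)
```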