Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-24 18:28:30 -05:00.

Compare commits: 5 commits, `v1.0.68` ... `leo/fix-ma`
| Author | SHA1 | Date |
|---|---|---|
| | e02b09c4a5 | |
| | 5dadb95662 | |
| | 9ea2ad2bc4 | |
| | 879e900c7c | |
| | c4d707efbc | |
```diff
@@ -129,10 +129,9 @@ class PipelineFirstLayer(CustomMlxLayer):
   def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
     if self.r != 0:
       x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
       if self.is_prefill:
         # We want to avoid GPU timeout errors by evalling the distributed operation
         # so that it stays on CPU, which does not have a timeout.
         mx.eval(x)
       # We want to avoid GPU timeout errors by evalling the distributed operation
       # so that it stays on CPU, which does not have a timeout.
       mx.eval(x)
     return self.original_layer(x, *args, **kwargs)
```
```diff
@@ -170,6 +170,7 @@ def warmup_inference(
     mx_barrier(group)

     logger.info("Generating warmup tokens")
     set_pipeline_prefill(model, is_prefill=True)
     for _r in stream_generate(
         model=model,
         tokenizer=tokenizer,
```
```diff
@@ -183,6 +184,7 @@ def warmup_inference(
     ):
         logger.info("Generated warmup token: " + str(_r.text))
         tokens_generated += 1
     set_pipeline_prefill(model, is_prefill=False)

     logger.info("Generated ALL warmup tokens")
```
```diff
@@ -214,6 +214,8 @@ def load_mlx_items(

     set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))

     mx.clear_cache()

     return cast(Model, model), tokenizer
```
Reference in New Issue
Block a user