Compare commits

...

5 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
e02b09c4a5 undo 2026-02-24 23:26:22 +00:00
Ryuichi Leo Takashige
5dadb95662 try without setting pipeline prefill on warmup 2026-02-24 23:24:07 +00:00
Ryuichi Leo Takashige
9ea2ad2bc4 always eval the recv 2026-02-24 23:23:16 +00:00
Ryuichi Leo Takashige
879e900c7c try setting pipeline prefill on warmup generate (will be gibberish) 2026-02-24 23:00:34 +00:00
Ryuichi Leo Takashige
c4d707efbc clear cache after model load 2026-02-24 22:52:05 +00:00
3 changed files with 7 additions and 4 deletions

View File

@@ -129,10 +129,9 @@ class PipelineFirstLayer(CustomMlxLayer):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
if self.is_prefill:
# We want to avoid GPU timeout errors by evalling the distributed operation
# so that it stays on CPU, which does not have a timeout.
mx.eval(x)
# We want to avoid GPU timeout errors by evalling the distributed operation
# so that it stays on CPU, which does not have a timeout.
mx.eval(x)
return self.original_layer(x, *args, **kwargs)

View File

@@ -170,6 +170,7 @@ def warmup_inference(
mx_barrier(group)
logger.info("Generating warmup tokens")
set_pipeline_prefill(model, is_prefill=True)
for _r in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -183,6 +184,7 @@ def warmup_inference(
):
logger.info("Generated warmup token: " + str(_r.text))
tokens_generated += 1
set_pipeline_prefill(model, is_prefill=False)
logger.info("Generated ALL warmup tokens")

View File

@@ -214,6 +214,8 @@ def load_mlx_items(
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
mx.clear_cache()
return cast(Model, model), tokenizer