From bce8ee3a6ec90da43dbefaec6ebfbe1426da8624 Mon Sep 17 00:00:00 2001
From: Ryuichi Leo Takashige
Date: Wed, 4 Feb 2026 19:18:25 +0000
Subject: [PATCH] Force synchronization points

---
 src/exo/worker/engines/mlx/generator/generate.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py
index 736a11c3..9afbad6d 100644
--- a/src/exo/worker/engines/mlx/generator/generate.py
+++ b/src/exo/worker/engines/mlx/generator/generate.py
@@ -116,6 +116,7 @@ def prefill(
 def warmup_inference(
     model: Model,
     tokenizer: TokenizerWrapper,
+    group: mx.distributed.Group | None = None,
 ) -> int:
     content = "Prompt to warm up the inference engine. Repeat this."
@@ -153,9 +154,7 @@ def warmup_inference(

     logger.info("Generated ALL warmup tokens")

-    # TODO: Do we want an mx_barrier?
-    # At least this version is actively incorrect, as it should use mx_barrier(group)
-    mx_barrier()
+    mx_barrier(group)

     return tokens_generated

@@ -184,6 +183,7 @@ def mlx_generate(
     task: TextGenerationTaskParams,
     prompt: str,
     kv_prefix_cache: KVPrefixCache | None = None,
+    group: mx.distributed.Group | None = None,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -366,10 +366,9 @@ def mlx_generate(
         )

         if is_done:
+            mx_barrier(group)
             break

         # Limit accumulated_text to what's needed for stop sequence detection
         if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
             accumulated_text = accumulated_text[-max_stop_len:]
-
-        # TODO: Do we want an mx_barrier?
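
Note (not part of the patch): mx_barrier is defined elsewhere in the exo MLX
engine and is not shown in this diff. Below is a minimal sketch of what a
group-aware barrier can look like in MLX, assuming it is implemented as an
eval'd all-reduce over the given group; this is an illustration, not the
repository's confirmed definition.

    import mlx.core as mx

    def mx_barrier(group: mx.distributed.Group | None = None) -> None:
        # Assumed sketch: force a synchronization point by evaluating a
        # tiny all-reduce across the group. mx.eval() blocks until the
        # collective completes, so every rank reaches this point before
        # any rank proceeds past it.
        # With group=None, MLX's all_sum falls back to the default
        # (global) group, which is why the patch threads the sub-group
        # through warmup_inference and mlx_generate instead of calling
        # mx_barrier() with no arguments.
        mx.eval(mx.distributed.all_sum(mx.array(0.0), group=group))

With the group passed through, the barrier after is_done synchronizes the
ranks that actually participated in this generation rather than whatever
default group a bare mx_barrier() would resolve to.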