diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py
index 3ec42ac4..354a1aed 100644
--- a/src/exo/worker/engines/mlx/cache.py
+++ b/src/exo/worker/engines/mlx/cache.py
@@ -258,6 +258,28 @@ class KVPrefixCache:
         return max_pressure


+def _fix_unmatched_think_end_tokens(tokens: mx.array, tokenizer: TokenizerWrapper) -> mx.array:
+    if not tokenizer.has_thinking:
+        return tokens
+    think_start_id: int = tokenizer.think_start_id  # type: ignore[attr-defined]
+    think_end_id: int = tokenizer.think_end_id  # type: ignore[attr-defined]
+    token_list: list[int] = cast(list[int], tokens.tolist())
+    result: list[int] = []
+    depth = 0
+    for token in token_list:
+        if token == think_start_id:
+            depth += 1
+        elif token == think_end_id:
+            if depth == 0:
+                result.append(think_start_id)
+            else:
+                depth -= 1
+        result.append(token)
+    if len(result) == len(token_list):
+        return tokens
+    return mx.array(result)
+
+
 def trim_cache(
     cache: KVCacheType,
     num_tokens: int,
@@ -282,7 +304,8 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
     """
     # Chat templates define their own structure - don't add BOS/EOS
     tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
-    return mx.array(tokenized_prompt)
+    tokens = mx.array(tokenized_prompt)
+    return _fix_unmatched_think_end_tokens(tokens, tokenizer)


 def cache_length(cache: KVCacheType) -> int:
@@ -319,9 +342,7 @@ def make_kv_cache(
     assert hasattr(model, "layers")

     # TODO: Do this for all models
-    if hasattr(model, "make_cache") and isinstance(
-        model, (GptOssModel, Qwen3NextModel)
-    ):
+    if hasattr(model, "make_cache"):
         logger.info("Using MLX LM's make cache")
         return model.make_cache()  # type: ignore
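
For context, a minimal sketch (not part of the patch) of how the new helper is expected to behave. It uses a duck-typed stub in place of a real TokenizerWrapper and made-up token ids (1000/1001); only the three attributes the helper actually reads are stubbed, and the expected output is stated as an assumption based on the function body above.

import mlx.core as mx

from exo.worker.engines.mlx.cache import _fix_unmatched_think_end_tokens


class StubTokenizer:
    # Hypothetical stand-in: the helper only reads these three attributes.
    has_thinking = True
    think_start_id = 1000  # plays the role of the thinking-start token id
    think_end_id = 1001    # plays the role of the thinking-end token id


# The leading 1001 has no matching 1000, so a synthetic start token is
# inserted in front of it; the already-balanced pair later on is untouched.
tokens = mx.array([5, 1001, 7, 1000, 8, 1001])
fixed = _fix_unmatched_think_end_tokens(tokens, StubTokenizer())
print(fixed.tolist())  # expected: [5, 1000, 1001, 7, 1000, 8, 1001]

If every thinking-end token is already matched, the rebuilt list has the same length as the input and the original tokens array is returned unchanged.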