Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-05 03:33:30 -05:00
Fix prompt for GLM
@@ -258,6 +258,28 @@ class KVPrefixCache:
         return max_pressure


+def _fix_unmatched_think_end_tokens(tokens: mx.array, tokenizer: TokenizerWrapper) -> mx.array:
+    if not tokenizer.has_thinking:
+        return tokens
+    think_start_id: int = tokenizer.think_start_id  # type: ignore[attr-defined]
+    think_end_id: int = tokenizer.think_end_id  # type: ignore[attr-defined]
+    token_list: list[int] = cast(list[int], tokens.tolist())
+    result: list[int] = []
+    depth = 0
+    for token in token_list:
+        if token == think_start_id:
+            depth += 1
+        elif token == think_end_id:
+            if depth == 0:
+                result.append(think_start_id)
+            else:
+                depth -= 1
+        result.append(token)
+    if len(result) == len(token_list):
+        return tokens
+    return mx.array(result)
+
+
 def trim_cache(
     cache: KVCacheType,
     num_tokens: int,
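The new helper walks the encoded prompt with a depth counter and inserts a matching <think> id directly before any </think> id that has no open think block, leaving already-balanced prompts untouched. A minimal standalone sketch of that pass, with invented token IDs standing in for the tokenizer's real ones:

# Sketch of the balancing pass above; 100/101 are invented stand-ins
# for tokenizer.think_start_id / tokenizer.think_end_id.
THINK_START, THINK_END = 100, 101

def balance_think_tokens(token_list: list[int]) -> list[int]:
    result: list[int] = []
    depth = 0
    for token in token_list:
        if token == THINK_START:
            depth += 1
        elif token == THINK_END:
            if depth == 0:
                # Unmatched </think>: pair it with an inserted <think>.
                result.append(THINK_START)
            else:
                depth -= 1
        result.append(token)
    return result

# A lone </think> (presumably what the GLM chat template emits) gets paired:
assert balance_think_tokens([7, 101, 8]) == [7, 100, 101, 8]
# Balanced input passes through unchanged:
assert balance_think_tokens([100, 7, 101]) == [100, 7, 101]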
@@ -282,7 +304,8 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
     """
     # Chat templates define their own structure - don't add BOS/EOS
     tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
-    return mx.array(tokenized_prompt)
+    tokens = mx.array(tokenized_prompt)
+    return _fix_unmatched_think_end_tokens(tokens, tokenizer)


 def cache_length(cache: KVCacheType) -> int:
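encode_prompt now routes every encoded prompt through the helper. A self-contained sketch of that flow, where StubTokenizer and its vocabulary are invented stand-ins for exo's TokenizerWrapper (only the control flow mirrors the patched function):

import mlx.core as mx

# StubTokenizer and its IDs are invented, not exo's real tokenizer.
class StubTokenizer:
    has_thinking = True
    think_start_id = 100
    think_end_id = 101

    def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
        vocab = {"<think>": 100, "</think>": 101}
        return [vocab.get(word, 7) for word in text.split()]

def fix_unmatched(tokens: mx.array, tok: StubTokenizer) -> mx.array:
    # Condensed copy of _fix_unmatched_think_end_tokens from the diff.
    result: list[int] = []
    depth = 0
    for t in tokens.tolist():
        if t == tok.think_start_id:
            depth += 1
        elif t == tok.think_end_id:
            if depth == 0:
                result.append(tok.think_start_id)
            else:
                depth -= 1
        result.append(t)
    # Fast path: nothing was inserted, so return the original array.
    return tokens if len(result) == tokens.size else mx.array(result)

def encode_prompt(tok: StubTokenizer, prompt: str) -> mx.array:
    # Chat templates define their own structure - don't add BOS/EOS.
    tokens = mx.array(tok.encode(prompt, add_special_tokens=False))
    return fix_unmatched(tokens, tok)

print(encode_prompt(StubTokenizer(), "hello </think> world").tolist())
# [7, 100, 101, 7]

Note the fast path: when nothing had to be inserted, the original mx.array is returned rather than rebuilt.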
@@ -319,9 +342,7 @@ def make_kv_cache(
     assert hasattr(model, "layers")

     # TODO: Do this for all models
-    if hasattr(model, "make_cache") and isinstance(
-        model, (GptOssModel, Qwen3NextModel)
-    ):
+    if hasattr(model, "make_cache"):
         logger.info("Using MLX LM's make cache")
         return model.make_cache()  # type: ignore

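The last hunk drops the GptOssModel/Qwen3NextModel allow-list: any model that exposes make_cache() now builds its own cache. A sketch of that duck-typed pattern, assuming mlx-lm-style caches (KVCache and the per-layer fallback are assumptions here, not exo's exact code):

from mlx_lm.models.cache import KVCache

def make_kv_cache_sketch(model):
    if hasattr(model, "make_cache"):
        # Models with specialized cache state (e.g. GPT-OSS, Qwen3-Next)
        # supply their own cache objects; defer to them.
        return model.make_cache()
    # Generic fallback: one standard KV cache per transformer layer.
    return [KVCache() for _ in model.layers]

mlx-lm's own make_prompt_cache performs the same hasattr check before falling back to default caches, which is presumably why the explicit allow-list was safe to remove.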