Fix prompt for GLM

Author: Ryuichi Leo Takashige
Date: 2026-02-04 19:53:53 +00:00
parent bce8ee3a6e
commit fe05608260


@@ -258,6 +258,28 @@ class KVPrefixCache:
         return max_pressure
 
 
+def _fix_unmatched_think_end_tokens(tokens: mx.array, tokenizer: TokenizerWrapper) -> mx.array:
+    if not tokenizer.has_thinking:
+        return tokens
+    think_start_id: int = tokenizer.think_start_id  # type: ignore[attr-defined]
+    think_end_id: int = tokenizer.think_end_id  # type: ignore[attr-defined]
+    token_list: list[int] = cast(list[int], tokens.tolist())
+    result: list[int] = []
+    depth = 0
+    for token in token_list:
+        if token == think_start_id:
+            depth += 1
+        elif token == think_end_id:
+            if depth == 0:
+                result.append(think_start_id)
+            else:
+                depth -= 1
+        result.append(token)
+    if len(result) == len(token_list):
+        return tokens
+    return mx.array(result)
+
+
 def trim_cache(
     cache: KVCacheType,
     num_tokens: int,
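For intuition: the new pass walks the token list tracking nesting depth, and inserts a synthetic <think> immediately before any </think> that has no open match; balanced input is returned as the original array. A minimal way to exercise it with a stub tokenizer (the stub class and the token ids 1001/1002 are illustrative, not GLM's real vocabulary):

import mlx.core as mx


class _StubThinkingTokenizer:
    """Hypothetical stand-in for TokenizerWrapper; ids are made up."""

    has_thinking = True
    think_start_id = 1001  # pretend <think>
    think_end_id = 1002    # pretend </think>


stub = _StubThinkingTokenizer()

# A bare </think>, as a GLM-style chat template can produce, gains an opener:
fixed = _fix_unmatched_think_end_tokens(mx.array([5, 1002, 7]), stub)  # type: ignore[arg-type]
assert fixed.tolist() == [5, 1001, 1002, 7]

# Already-balanced sequences come back unchanged (the original array):
same = _fix_unmatched_think_end_tokens(mx.array([1001, 6, 1002]), stub)  # type: ignore[arg-type]
assert same.tolist() == [1001, 6, 1002]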
@@ -282,7 +304,8 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
     """
     # Chat templates define their own structure - don't add BOS/EOS
     tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
-    return mx.array(tokenized_prompt)
+    tokens = mx.array(tokenized_prompt)
+    return _fix_unmatched_think_end_tokens(tokens, tokenizer)
 
 
 def cache_length(cache: KVCacheType) -> int:
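The effect of routing encode_prompt through the fix is an invariant: every </think> id in the returned tokens is preceded by a matching <think>. Stated as a small test predicate (this helper is ours, not part of the repo):

def think_tokens_balanced(tokens: list[int], start_id: int, end_id: int) -> bool:
    """True iff no end token appears without an open start token."""
    depth = 0
    for token in tokens:
        if token == start_id:
            depth += 1
        elif token == end_id:
            if depth == 0:
                return False
            depth -= 1
    return True


# Holds for the repaired sequence from the sketch above:
assert think_tokens_balanced([5, 1001, 1002, 7], start_id=1001, end_id=1002)
assert not think_tokens_balanced([5, 1002, 7], start_id=1001, end_id=1002)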
@@ -319,9 +342,7 @@ def make_kv_cache(
     assert hasattr(model, "layers")
 
     # TODO: Do this for all models
-    if hasattr(model, "make_cache") and isinstance(
-        model, (GptOssModel, Qwen3NextModel)
-    ):
+    if hasattr(model, "make_cache"):
         logger.info("Using MLX LM's make cache")
         return model.make_cache()  # type: ignore
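The last hunk widens cache construction from an allow-list of two model classes to duck typing: any model that defines make_cache now supplies its own cache layout, which is what the retained "# TODO: Do this for all models" comment asked for. A sketch of the resulting dispatch, assuming mlx_lm's KVCache for the generic path; the function name and the per-layer fallback are illustrative, not the repo's exact code:

from typing import Any

from mlx_lm.models.cache import KVCache  # plain per-layer KV cache


def make_kv_cache_sketch(model: Any) -> list[Any]:
    if hasattr(model, "make_cache"):
        # The model knows its own layout (sliding windows, SSM state, ...).
        return model.make_cache()
    # Illustrative fallback: one plain KVCache per transformer layer.
    return [KVCache() for _ in model.layers]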