Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
0c6e943cea replace model ID string matching with architecture-based checks in tokenizer loading
Use config.json model_type field instead of fragile substring matching on
model IDs for architecture detection in tokenizer loading. Falls back to
string matching when config.json is unavailable.

Closes #1371

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 10:19:09 -08:00
9 changed files with 71 additions and 54 deletions

View File

@@ -143,12 +143,7 @@ from exo.shared.types.openai_responses import (
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -315,7 +310,6 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
mlx_device=payload.mlx_device,
)
await self._send(command)
@@ -356,7 +350,6 @@ class API:
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
mlx_device: MlxDevice = MlxDevice.Auto,
) -> Instance:
model_card = await ModelCard.load(model_id)
@@ -367,7 +360,6 @@ class API:
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
mlx_device=mlx_device,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,

View File

@@ -159,7 +159,6 @@ def place_instance(
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
mlx_device=command.mlx_device,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -174,7 +173,6 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
mlx_device=command.mlx_device,
)
return target_instances

View File

@@ -165,6 +165,7 @@ def is_custom_card(model_id: ModelId) -> bool:
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
model_type: str | None = None
architectures: list[str] | None = None
hidden_size: Annotated[int, Field(ge=0)] | None = None
layer_count: int = Field(
@@ -200,6 +201,7 @@ class ConfigData(BaseModel):
return data
for field in [
"model_type",
"architectures",
"hidden_size",
"num_hidden_layers",

View File

@@ -8,12 +8,7 @@ from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
@@ -231,7 +226,6 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
mlx_device: MlxDevice = MlxDevice.Auto
class CreateInstanceParams(BaseModel):

View File

@@ -8,12 +8,7 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxDevice,
)
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -43,7 +38,6 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
mlx_device: MlxDevice = MlxDevice.Auto
class CreateInstance(BaseCommand):

View File

@@ -16,16 +16,9 @@ class InstanceMeta(str, Enum):
MlxJaccl = "MlxJaccl"
class MlxDevice(str, Enum):
Auto = "Auto"
Cpu = "Cpu"
Gpu = "Gpu"
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
mlx_device: MlxDevice = MlxDevice.Auto
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -269,19 +269,52 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
def _read_model_type_from_config(model_path: Path) -> str | None:
"""Read the model_type field from config.json at the given model path.
Some models require explicit EOS token configuration that isn't in their
tokenizer config. This function returns the known EOS token IDs for such models.
Returns None if config.json doesn't exist or doesn't contain model_type.
"""
config_path = model_path / "config.json"
if not config_path.exists():
return None
try:
with open(config_path) as f:
config: dict[str, Any] = json.load(f) # pyright: ignore[reportAny]
model_type: Any = config.get("model_type")
if model_type is None:
text_config: Any = config.get("text_config")
if isinstance(text_config, dict):
model_type = text_config.get("model_type") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
return model_type if isinstance(model_type, str) else None
except (json.JSONDecodeError, OSError):
return None
def get_eos_token_ids_for_model(
model_id: ModelId, model_type: str | None = None
) -> list[int] | None:
"""Get the EOS token IDs for a model based on its architecture type.
Uses model_type from config.json when available, falls back to model_id
string matching for backward compatibility.
Args:
model_id: The HuggingFace model ID
model_type: The model_type field from config.json (e.g., "kimi", "glm4")
Returns:
List of EOS token IDs, or None if the model uses standard tokenizer config
"""
if model_type is not None:
if model_type == "kimi":
return [163586]
elif model_type == "glm4_moe_lite":
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif model_type.startswith("glm"):
return [151336, 151329, 151338]
# Fallback: string matching on model_id
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
@@ -296,11 +329,10 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
"""Load tokenizer for a model given its ID and local path.
This is the core tokenizer loading logic, handling special cases for different
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
Uses model_type from config.json for architecture detection when available,
falling back to model_id string matching for backward compatibility.
Args:
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
@@ -309,11 +341,21 @@ def load_tokenizer_for_model_id(
Returns:
TokenizerWrapper instance configured for the model
"""
model_type = _read_model_type_from_config(model_path)
model_id_lower = model_id.lower()
eos_token_ids = get_eos_token_ids_for_model(model_id)
eos_token_ids = get_eos_token_ids_for_model(model_id, model_type=model_type)
is_kimi = (
model_type == "kimi" if model_type is not None else "kimi-k2" in model_id_lower
)
is_gemma3 = (
model_type == "gemma3"
if model_type is not None
else "gemma-3" in model_id_lower
)
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
if "kimi-k2" in model_id_lower:
if is_kimi:
import importlib.util
import types
@@ -367,7 +409,7 @@ def load_tokenizer_for_model_id(
eos_token_ids=eos_token_ids,
)
if "gemma-3" in model_id_lower:
if is_gemma3:
gemma_3_eos_id = 1
gemma_3_end_of_turn_id = 106
if tokenizer.eos_token_ids is not None:

View File

@@ -4,7 +4,7 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxDevice, MlxJacclInstance
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -35,15 +35,6 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Set MLX compute device before importing runner (which imports mlx.core at module scope)
mlx_device = bound_instance.instance.mlx_device
if mlx_device != MlxDevice.Auto:
import mlx.core as mx
device = mx.cpu if mlx_device == MlxDevice.Cpu else mx.gpu
mx.set_default_device(device)
logger.info(f"MLX device set to: {mlx_device}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main

View File

@@ -24,6 +24,7 @@ from exo.worker.engines.mlx.utils_mlx import (
# Files needed for tokenizer functionality
TOKENIZER_FILE_PATTERNS = [
"config.json",
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
@@ -338,6 +339,9 @@ async def test_kimi_tokenizer_specifically():
# Verify EOS token is set
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
# Verify architecture-based detection gives same result
assert get_eos_token_ids_for_model(model_id, model_type="kimi") == [163586]
# Test GLM tokenizer since it also has special handling
@pytest.mark.asyncio
@@ -378,3 +382,10 @@ async def test_glm_tokenizer_specifically():
151329,
151338,
], "GLM EOS tokens should be correct"
# Verify architecture-based detection gives same result
assert get_eos_token_ids_for_model(model_id, model_type="glm4") == [
151336,
151329,
151338,
]