mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 14:55:13 -05:00
Compare commits
1 Commits
support-ml
...
architectu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c6e943cea |
@@ -143,12 +143,7 @@ from exo.shared.types.openai_responses import (
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxDevice,
|
||||
)
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
@@ -315,7 +310,6 @@ class API:
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
mlx_device=payload.mlx_device,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
@@ -356,7 +350,6 @@ class API:
|
||||
sharding: Sharding = Sharding.Pipeline,
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
||||
mlx_device: MlxDevice = MlxDevice.Auto,
|
||||
) -> Instance:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
|
||||
@@ -367,7 +360,6 @@ class API:
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
mlx_device=mlx_device,
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
|
||||
@@ -159,7 +159,6 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
jaccl_devices=mlx_jaccl_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
mlx_device=command.mlx_device,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
@@ -174,7 +173,6 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
mlx_device=command.mlx_device,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -165,6 +165,7 @@ def is_custom_card(model_id: ModelId) -> bool:
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
model_type: str | None = None
|
||||
architectures: list[str] | None = None
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
layer_count: int = Field(
|
||||
@@ -200,6 +201,7 @@ class ConfigData(BaseModel):
|
||||
return data
|
||||
|
||||
for field in [
|
||||
"model_type",
|
||||
"architectures",
|
||||
"hidden_size",
|
||||
"num_hidden_layers",
|
||||
|
||||
@@ -8,12 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxDevice,
|
||||
)
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -231,7 +226,6 @@ class PlaceInstanceParams(BaseModel):
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
mlx_device: MlxDevice = MlxDevice.Auto
|
||||
|
||||
|
||||
class CreateInstanceParams(BaseModel):
|
||||
|
||||
@@ -8,12 +8,7 @@ from exo.shared.types.api import (
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxDevice,
|
||||
)
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -43,7 +38,6 @@ class PlaceInstance(BaseCommand):
|
||||
sharding: Sharding
|
||||
instance_meta: InstanceMeta
|
||||
min_nodes: int
|
||||
mlx_device: MlxDevice = MlxDevice.Auto
|
||||
|
||||
|
||||
class CreateInstance(BaseCommand):
|
||||
|
||||
@@ -16,16 +16,9 @@ class InstanceMeta(str, Enum):
|
||||
MlxJaccl = "MlxJaccl"
|
||||
|
||||
|
||||
class MlxDevice(str, Enum):
|
||||
Auto = "Auto"
|
||||
Cpu = "Cpu"
|
||||
Gpu = "Gpu"
|
||||
|
||||
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
mlx_device: MlxDevice = MlxDevice.Auto
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -269,19 +269,52 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
"""
|
||||
Get the EOS token IDs for a model based on its ID.
|
||||
def _read_model_type_from_config(model_path: Path) -> str | None:
|
||||
"""Read the model_type field from config.json at the given model path.
|
||||
|
||||
Some models require explicit EOS token configuration that isn't in their
|
||||
tokenizer config. This function returns the known EOS token IDs for such models.
|
||||
Returns None if config.json doesn't exist or doesn't contain model_type.
|
||||
"""
|
||||
config_path = model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config: dict[str, Any] = json.load(f) # pyright: ignore[reportAny]
|
||||
model_type: Any = config.get("model_type")
|
||||
if model_type is None:
|
||||
text_config: Any = config.get("text_config")
|
||||
if isinstance(text_config, dict):
|
||||
model_type = text_config.get("model_type") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
return model_type if isinstance(model_type, str) else None
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def get_eos_token_ids_for_model(
|
||||
model_id: ModelId, model_type: str | None = None
|
||||
) -> list[int] | None:
|
||||
"""Get the EOS token IDs for a model based on its architecture type.
|
||||
|
||||
Uses model_type from config.json when available, falls back to model_id
|
||||
string matching for backward compatibility.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID
|
||||
model_type: The model_type field from config.json (e.g., "kimi", "glm4")
|
||||
|
||||
Returns:
|
||||
List of EOS token IDs, or None if the model uses standard tokenizer config
|
||||
"""
|
||||
if model_type is not None:
|
||||
if model_type == "kimi":
|
||||
return [163586]
|
||||
elif model_type == "glm4_moe_lite":
|
||||
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
|
||||
return [154820, 154827, 154829]
|
||||
elif model_type.startswith("glm"):
|
||||
return [151336, 151329, 151338]
|
||||
|
||||
# Fallback: string matching on model_id
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
@@ -296,11 +329,10 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
|
||||
def load_tokenizer_for_model_id(
|
||||
model_id: ModelId, model_path: Path
|
||||
) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
"""Load tokenizer for a model given its ID and local path.
|
||||
|
||||
This is the core tokenizer loading logic, handling special cases for different
|
||||
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
|
||||
Uses model_type from config.json for architecture detection when available,
|
||||
falling back to model_id string matching for backward compatibility.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
|
||||
@@ -309,11 +341,21 @@ def load_tokenizer_for_model_id(
|
||||
Returns:
|
||||
TokenizerWrapper instance configured for the model
|
||||
"""
|
||||
model_type = _read_model_type_from_config(model_path)
|
||||
model_id_lower = model_id.lower()
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id, model_type=model_type)
|
||||
|
||||
is_kimi = (
|
||||
model_type == "kimi" if model_type is not None else "kimi-k2" in model_id_lower
|
||||
)
|
||||
is_gemma3 = (
|
||||
model_type == "gemma3"
|
||||
if model_type is not None
|
||||
else "gemma-3" in model_id_lower
|
||||
)
|
||||
|
||||
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
|
||||
if "kimi-k2" in model_id_lower:
|
||||
if is_kimi:
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
@@ -367,7 +409,7 @@ def load_tokenizer_for_model_id(
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
if "gemma-3" in model_id_lower:
|
||||
if is_gemma3:
|
||||
gemma_3_eos_id = 1
|
||||
gemma_3_end_of_turn_id = 106
|
||||
if tokenizer.eos_token_ids is not None:
|
||||
|
||||
@@ -4,7 +4,7 @@ import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxDevice, MlxJacclInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
@@ -35,15 +35,6 @@ def entrypoint(
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Set MLX compute device before importing runner (which imports mlx.core at module scope)
|
||||
mlx_device = bound_instance.instance.mlx_device
|
||||
if mlx_device != MlxDevice.Auto:
|
||||
import mlx.core as mx
|
||||
|
||||
device = mx.cpu if mlx_device == MlxDevice.Cpu else mx.gpu
|
||||
mx.set_default_device(device)
|
||||
logger.info(f"MLX device set to: {mlx_device}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
@@ -24,6 +24,7 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
|
||||
# Files needed for tokenizer functionality
|
||||
TOKENIZER_FILE_PATTERNS = [
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"special_tokens_map.json",
|
||||
@@ -338,6 +339,9 @@ async def test_kimi_tokenizer_specifically():
|
||||
# Verify EOS token is set
|
||||
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
|
||||
|
||||
# Verify architecture-based detection gives same result
|
||||
assert get_eos_token_ids_for_model(model_id, model_type="kimi") == [163586]
|
||||
|
||||
|
||||
# Test GLM tokenizer since it also has special handling
|
||||
@pytest.mark.asyncio
|
||||
@@ -378,3 +382,10 @@ async def test_glm_tokenizer_specifically():
|
||||
151329,
|
||||
151338,
|
||||
], "GLM EOS tokens should be correct"
|
||||
|
||||
# Verify architecture-based detection gives same result
|
||||
assert get_eos_token_ids_for_model(model_id, model_type="glm4") == [
|
||||
151336,
|
||||
151329,
|
||||
151338,
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user