mirror of https://github.com/exo-explore/exo.git (synced 2026-01-20 20:10:10 -05:00)

Compare commits: 7 commits (tool-calli... → model-card)
| Author | SHA1 | Date |
|---|---|---|
| | 92b24196c3 | |
| | 3bf7770988 | |
| | 8392463a70 | |
| | 9c1f6224b0 | |
| | f370dbd1e0 | |
| | 6a38f9efba | |
| | 0475de6431 | |
@@ -10,6 +10,7 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"

if not ENTRYPOINT.is_file():

@@ -18,6 +19,9 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
    raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")

if not RESOURCES_DIR.is_dir():
    raise SystemExit(f"Resources are missing: {RESOURCES_DIR}")

if not EXO_SHARED_MODELS_DIR.is_dir():
    raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")

@@ -58,6 +62,7 @@ HIDDEN_IMPORTS = sorted(

DATAS: list[tuple[str, str]] = [
    (str(DASHBOARD_DIR), "dashboard"),
    (str(RESOURCES_DIR), "resources"),
    (str(MLX_LIB_DIR), "mlx/lib"),
    (str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]
@@ -24,7 +24,6 @@ dependencies = [
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",
    "httpx>=0.28.1",
    "tomlkit>=0.14.0",
]

[project.scripts]
35 new 7-line model-card TOML files are added (`@@ -0,0 +1,7 @@` each). Those with visible paths live under `resources/`, named after the model id with `/` replaced by `--` (e.g. `resources/mlx-community--DeepSeek-V3.1-4bit.toml`). Every file has the same shape: `model_id`, `n_layers`, `hidden_size`, `supports_tensor`, and a `[storage_size]` table with `in_bytes`. Consolidated:

| model_id | n_layers | hidden_size | supports_tensor | storage_size.in_bytes |
|---|---|---|---|---|
| mlx-community/DeepSeek-V3.1-4bit | 61 | 7168 | true | 405874409472 |
| mlx-community/DeepSeek-V3.1-8bit | 61 | 7168 | true | 765577920512 |
| mlx-community/GLM-4.5-Air-8bit | 46 | 4096 | false | 122406567936 |
| mlx-community/GLM-4.5-Air-bf16 | 46 | 4096 | true | 229780750336 |
| mlx-community/GLM-4.7-4bit | 91 | 5120 | true | 198556925568 |
| mlx-community/GLM-4.7-6bit | 91 | 5120 | true | 286737579648 |
| mlx-community/GLM-4.7-8bit-gs32 | 91 | 5120 | true | 396963397248 |
| mlx-community/Kimi-K2-Instruct-4bit | 61 | 7168 | true | 620622774272 |
| mlx-community/Kimi-K2-Thinking | 61 | 7168 | true | 706522120192 |
| mlx-community/Llama-3.2-1B-Instruct-4bit | 16 | 2048 | true | 729808896 |
| mlx-community/Llama-3.2-3B-Instruct-4bit | 28 | 3072 | true | 1863319552 |
| mlx-community/Llama-3.2-3B-Instruct-8bit | 28 | 3072 | true | 3501195264 |
| mlx-community/Llama-3.3-70B-Instruct-4bit | 80 | 8192 | true | 40652242944 |
| mlx-community/Llama-3.3-70B-Instruct-8bit | 80 | 8192 | true | 76799803392 |
| mlx-community/Meta-Llama-3.1-70B-Instruct-4bit | 80 | 8192 | true | 40652242944 |
| mlx-community/Meta-Llama-3.1-8B-Instruct-4bit | 32 | 4096 | true | 4637851648 |
| mlx-community/Meta-Llama-3.1-8B-Instruct-8bit | 32 | 4096 | true | 8954839040 |
| mlx-community/Meta-Llama-3.1-8B-Instruct-bf16 | 32 | 4096 | true | 16882073600 |
| mlx-community/MiniMax-M2.1-3bit | 61 | 3072 | true | 100086644736 |
| mlx-community/MiniMax-M2.1-8bit | 61 | 3072 | true | 242986745856 |
| mlx-community/Qwen3-0.6B-4bit | 28 | 1024 | false | 342884352 |
| mlx-community/Qwen3-0.6B-8bit | 28 | 1024 | false | 698351616 |
| mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit | 94 | 4096 | true | 141733920768 |
| mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit | 94 | 4096 | true | 268435456000 |
| mlx-community/Qwen3-30B-A3B-4bit | 48 | 2048 | true | 17612931072 |
| mlx-community/Qwen3-30B-A3B-8bit | 48 | 2048 | true | 33279705088 |
| mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit | 62 | 6144 | true | 289910292480 |
| mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit | 62 | 6144 | true | 579820584960 |
| mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit | 48 | 2048 | true | 46976204800 |
| mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit | 48 | 2048 | true | 88814387200 |
| mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit | 48 | 2048 | true | 88814387200 |
| mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit | 48 | 2048 | true | 88814387200 |
| mlx-community/gpt-oss-120b-MXFP4-Q8 | 36 | 2880 | true | 70652212224 |
| mlx-community/gpt-oss-20b-MXFP4-Q8 | 24 | 2880 | true | 12025908224 |
| mlx-community/llama-3.3-70b-instruct-fp16 | 80 | 8192 | true | 144383672320 |
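Each card is plain TOML, so the sizes are easy to sanity-check. A minimal sketch, not part of the change itself; it only assumes the seven fields shown above and the `tomlkit` dependency listed in `pyproject.toml`:

```python
# Parse one of the new model-card TOML files and convert its size to GiB.
import tomlkit

CARD = """\
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true

[storage_size]
in_bytes = 405874409472
"""

doc = tomlkit.loads(CARD)
gib = doc["storage_size"]["in_bytes"] / 1024**3
print(doc["model_id"], gib)  # mlx-community/DeepSeek-V3.1-4bit 378.0
```

405874409472 bytes is exactly 378 GiB, which matches the `Memory.from_gb(378)` figure in the hard-coded `MODEL_CARDS` entry further down.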
@@ -1,5 +1,6 @@
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import cast

@@ -19,11 +20,7 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
    MODEL_CARDS,
    ModelCard,
    ModelId,
)
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.types.api import (
    BenchChatCompletionResponse,
    BenchChatCompletionTaskParams,
@@ -44,7 +41,7 @@ from exo.shared.types.api import (
    PlacementPreviewResponse,
    StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import (
    ChatCompletion,
    Command,
@@ -68,12 +65,12 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.dashboard_path import RuntimeResources, find_directory
from exo.utils.event_buffer import OrderedBuffer


def chunk_to_response(
    chunk: TokenChunk | ToolCallChunk, command_id: CommandId
    chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
    return ChatCompletionResponse(
        id=command_id,
@@ -85,72 +82,56 @@ def chunk_to_response(
                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
                finish_reason=chunk.finish_reason,
            )
            if isinstance(chunk, TokenChunk)
            else StreamingChoiceResponse(
                index=0,
                delta=ChatCompletionMessage(
                    role="assistant",
                    tool_calls=[tool.model_dump() for tool in chunk.tool_calls],
                ),
                finish_reason="tool_calls",
            )
        ],
    )


async def resolve_model_card(model_id: ModelId) -> ModelCard:
    if model_id in MODEL_CARDS:
        model_card = MODEL_CARDS[model_id]
        return model_card
    else:
        return await ModelCard.from_hf(model_id)


@dataclass(eq=False)
class API:
    def __init__(
        self,
    node_id: NodeId
    session_id: SessionId
    port: int
    app: FastAPI
    global_event_receiver: Receiver[ForwarderEvent]
    command_sender: Sender[ForwarderCommand]
    election_receiver: Receiver[ElectionMessage]
    state = field(init=False, default_factory=State)
    _event_log: list[Event] = field(init=False, default_factory=list)
    event_buffer: OrderedBuffer[Event] = field(init=False, default_factory=OrderedBuffer)
    _chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = field(init=False, default_factory=dict)
    _tg: TaskGroup = field(init=False, default_factory=create_task_group)
    last_completed_election: int = field(init=False, default=0)
    paused: bool = field(init=False, default = False)
    paused_ev: anyio.Event = field(init=False, default_factory=anyio.Event)

    @classmethod
    async def create(
        cls,
        node_id: NodeId,
        session_id: SessionId,
        *,
        port: int,
        # Ideally this would be a MasterForwarderEvent but type system says no :(
        global_event_receiver: Receiver[ForwarderEvent],
        command_sender: Sender[ForwarderCommand],
        # This lets us pause the API if an election is running
        election_receiver: Receiver[ElectionMessage],
    ) -> None:
        self.state = State()
        self._event_log: list[Event] = []
        self.command_sender = command_sender
        self.global_event_receiver = global_event_receiver
        self.election_receiver = election_receiver
        self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
        self.node_id: NodeId = node_id
        self.session_id: SessionId = session_id
        self.last_completed_election: int = 0
        self.port = port

        self.paused: bool = False
        self.paused_ev: anyio.Event = anyio.Event()

        self.app = FastAPI()
        self._setup_exception_handlers()
        self._setup_cors()
        self._setup_routes()

        self.app.mount(
        app = FastAPI()
        app.mount(
            "/",
            StaticFiles(
                directory=find_dashboard(),
                directory=await find_directory(RuntimeResources.Dashboard),
                html=True,
            ),
            name="dashboard",
        )

        self._chat_completion_queues: dict[
            CommandId, Sender[TokenChunk | ToolCallChunk]
        ] = {}
        self._tg: TaskGroup | None = None
        cls(node_id, session_id, port, app, global_event_receiver, command_sender, election_receiver)

    def __post_init__(self) -> None:
        self._setup_exception_handlers()
        self._setup_cors()
        self._setup_routes()

    def reset(self, new_session_id: SessionId, result_clock: int):
        logger.info("Resetting API State")
@@ -227,7 +208,7 @@ class API:
        self, payload: CreateInstanceParams
    ) -> CreateInstanceResponse:
        instance = payload.instance
        model_card = await resolve_model_card(instance.shard_assignments.model_id)
        model_card = await ModelCard.from_hf(instance.shard_assignments.model_id)
        required_memory = model_card.storage_size
        available_memory = self._calculate_total_available_memory()

@@ -250,7 +231,7 @@ class API:

    async def get_placement(
        self,
        model_id: ModelId,
        model_id: str,
        sharding: Sharding = Sharding.Pipeline,
        instance_meta: InstanceMeta = InstanceMeta.MlxRing,
        min_nodes: int = 1,
@@ -293,7 +274,7 @@ class API:
        if len(list(self.state.topology.list_nodes())) == 0:
            return PlacementPreviewResponse(previews=[])

        cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
        cards = [card for card in await get_model_cards() if card.short_id == model_id]
        if not cards:
            raise HTTPException(status_code=404, detail=f"Model {model_id} not found")

@@ -415,21 +396,16 @@ class API:

    async def _chat_chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[TokenChunk | ToolCallChunk, None]:
    ) -> AsyncGenerator[TokenChunk, None]:
        """Yield `TokenChunk`s for a given command until completion."""

        try:
            self._chat_completion_queues[command_id], recv = channel[
                TokenChunk | ToolCallChunk
            ]()
            self._chat_completion_queues[command_id], recv = channel[TokenChunk]()

            with recv as token_chunks:
                async for chunk in token_chunks:
                    yield chunk
                    if (
                        isinstance(chunk, TokenChunk)
                        and chunk.finish_reason is not None
                    ):
                    if chunk.finish_reason is not None:
                        break

        except anyio.get_cancelled_exc_class():
@@ -451,7 +427,7 @@ class API:
        """Generate chat completion stream as JSON strings."""

        async for chunk in self._chat_chunk_stream(command_id):
            if isinstance(chunk, TokenChunk) and chunk.finish_reason == "error":
            if chunk.finish_reason == "error":
                error_response = ErrorResponse(
                    error=ErrorInfo(
                        message=chunk.error_message or "Internal server error",
@@ -470,7 +446,7 @@ class API:

            yield f"data: {chunk_response.model_dump_json()}\n\n"

            if isinstance(chunk, ToolCallChunk) or chunk.finish_reason is not None:
            if chunk.finish_reason is not None:
                yield "data: [DONE]\n\n"

    async def _collect_chat_completion(
@@ -483,24 +459,6 @@ class API:
        finish_reason: FinishReason | None = None

        async for chunk in self._chat_chunk_stream(command_id):
            if isinstance(chunk, ToolCallChunk):
                finish_reason = "tool_calls"
                return ChatCompletionResponse(
                    id=command_id,
                    created=int(time.time()),
                    model=model or chunk.model,
                    choices=[
                        ChatCompletionChoice(
                            index=0,
                            message=ChatCompletionMessage(
                                role="assistant",
                                tool_calls=[
                                    tool.model_dump() for tool in chunk.tool_calls
                                ],
                            ),
                        )
                    ],
                )
            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,
@@ -544,11 +502,6 @@ class API:
        stats: GenerationStats | None = None

        async for chunk in self._chat_chunk_stream(command_id):
            if isinstance(chunk, ToolCallChunk):
                raise HTTPException(
                    status_code=500,
                    detail="Tool call in bench",
                )
            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,
@@ -593,7 +546,7 @@ class API:
        self, payload: ChatCompletionTaskParams
    ) -> ChatCompletionResponse | StreamingResponse:
        """Handle chat completions, supporting both streaming and non-streaming responses."""
        model_card = await resolve_model_card(ModelId(payload.model))
        model_card = await resolve_model_card(payload.model)
        payload.model = model_card.model_id

        if not any(
@@ -620,7 +573,7 @@ class API:
    async def bench_chat_completions(
        self, payload: BenchChatCompletionTaskParams
    ) -> BenchChatCompletionResponse:
        model_card = await resolve_model_card(ModelId(payload.model))
        model_card = await resolve_model_card(payload.model)
        payload.model = model_card.model_id

        if not any(
@@ -662,7 +615,7 @@ class API:
                storage_size_megabytes=int(card.storage_size.in_mb),
                supports_tensor=card.supports_tensor,
            )
            for card in MODEL_CARDS.values()
            for card in model_cards()
        ]
    )

@@ -699,7 +652,7 @@ class API:
        self._event_log.append(event)
        self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
        if isinstance(event, ChunkGenerated):
            assert isinstance(event.chunk, (TokenChunk, ToolCallChunk))
            assert isinstance(event.chunk, TokenChunk)
            queue = self._chat_completion_queues.get(event.command_id)
            if queue is not None:
                try:
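The streaming handler above frames each chunk as an SSE `data:` line and closes the stream with `data: [DONE]`. A hypothetical client-side sketch of that framing, using `httpx` (already a dependency); the endpoint path, port, model name, and response JSON keys are assumptions for illustration, not taken from this diff:

```python
import json
import httpx

def stream_chat(prompt: str, base_url: str = "http://localhost:52415") -> str:
    # Assumed OpenAI-style endpoint and payload; adjust to the actual deployment.
    payload = {
        "model": "llama-3.2-1b",
        "stream": True,
        "messages": [{"role": "user", "content": prompt}],
    }
    parts: list[str] = []
    with httpx.stream("POST", f"{base_url}/v1/chat/completions",
                      json=payload, timeout=None) as r:
        for line in r.iter_lines():
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":          # end-of-stream marker emitted above
                break
            chunk = json.loads(data)
            delta = chunk["choices"][0]["delta"].get("content") or ""
            parts.append(delta)
    return "".join(parts)
```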
@@ -276,7 +276,9 @@ def test_placement_selects_leaf_nodes(
    # arrange
    topology = Topology()

    model_card.storage_size = Memory.from_bytes(1000)
    # Model requires more than any single node but fits within a 3-node cycle
    model_card.storage_size.in_bytes = 1500
    model_card.n_layers = 12

    node_id_a = NodeId()
    node_id_b = NodeId()
@@ -6,14 +6,29 @@ import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from pydantic import BaseModel, Field, PositiveInt, ValidationError
from tomlkit.exceptions import TOMLKitError

from exo.shared.types.common import ModelId
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.dashboard_path import RuntimeResources, find_directory
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,
    ensure_models_dir,
)

_card_cache: dict[str, "ModelCard"] = {}

class ModelId(Id):
    def normalize(self) -> str:
        return self.replace("/", "--")

    def short(self) -> str:
        return self.split("/")[-1]

_card_cache: dict[str, ModelCard] = {}

class ModelCard(CamelCaseModel):
    model_id: ModelId
@@ -28,22 +43,32 @@ class ModelCard(CamelCaseModel):
        data = tomlkit.dumps(py)  # pyright: ignore[reportUnknownMemberType]
        await f.write(data)

    async def save_to_default_path(self) -> None:
        dir = await find_directory(RuntimeResources.Resources)
        await self.save(dir / self.model_id.normalize())

    @staticmethod
    async def load_from_path(path: Path) -> "ModelCard":
    async def load_from_path(path: Path) -> ModelCard:
        async with await open_file(path, "r") as f:
            py = tomlkit.loads(await f.read())
            return ModelCard.model_validate(py)

    @staticmethod
    async def load(model_id: ModelId) -> "ModelCard":
        if model_id in MODEL_CARDS:
            return MODEL_CARDS[model_id]
        return await ModelCard.from_hf(model_id)
    async def load_from_default_path(model_id: ModelId) -> ModelCard:
        return await ModelCard.load_from_path(await find_directory(RuntimeResources.Resources) / model_id.normalize())

    @staticmethod
    async def from_hf(model_id: ModelId) -> "ModelCard":
    async def load(model_id: ModelId) -> ModelCard:
        try:
            return await ModelCard.load_from_default_path(model_id)
        except (ValidationError, TOMLKitError, FileNotFoundError):
            return await ModelCard.from_hf(model_id)

    @staticmethod
    async def from_hf(model_id: ModelId) -> ModelCard:
        """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
        if (mc := _card_cache.get(model_id)) is not None:
        if (mc := _card_cache.get(model_id, None)) is not None:
            return mc
        config_data = await get_config_data(model_id)
        num_layers = config_data.layer_count
@@ -54,254 +79,25 @@ class ModelCard(CamelCaseModel):
            storage_size=mem_size_bytes,
            n_layers=num_layers,
            hidden_size=config_data.hidden_size or 0,
            supports_tensor=config_data.supports_tensor,
            # TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
            supports_tensor=False,
        )
        _card_cache[model_id] = mc
        return mc

# TODO: should we cache this? how do we check for changes
async def get_model_cards() -> list[ModelCard]:
    dir = await find_directory(RuntimeResources.Resources)
    cards: list[ModelCard] = []
    async for file in dir.glob("*.toml"):
        try:
            cards.append(await ModelCard.load_from_path(file))
        except (TOMLKitError, ValidationError):
            continue

    return cards
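A hypothetical usage sketch of the lookup order introduced here: `ModelCard.load()` tries the bundled TOML under `resources/` first and only falls back to `ModelCard.from_hf()` when the file is missing or fails to parse or validate, while `get_model_cards()` enumerates every valid `*.toml` in that directory:

```python
import anyio
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards

async def main() -> None:
    # Served from the bundled resources/ TOML if it exists and validates;
    # otherwise fetched from Hugging Face via ModelCard.from_hf().
    card = await ModelCard.load(ModelId("mlx-community/Qwen3-0.6B-4bit"))
    print(card.n_layers, card.hidden_size, card.supports_tensor)

    # All cards shipped in resources/ (invalid TOML files are skipped).
    cards = await get_model_cards()
    print(len(cards))

anyio.run(main)
```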
The same hunk (`@@ -54,254 +79,25 @@`) goes on to delete the entire hard-coded registry, `MODEL_CARDS: dict[str, ModelCard] = { ... }`; its entries now ship as the `resources/` TOML files added above. Consolidated, the entries visible in this diff:

| key | model_id | storage_size | n_layers | hidden_size | supports_tensor |
|---|---|---|---|---|---|
| deepseek-v3.1-4bit | mlx-community/DeepSeek-V3.1-4bit | Memory.from_gb(378) | 61 | 7168 | True |
| deepseek-v3.1-8bit | mlx-community/DeepSeek-V3.1-8bit | Memory.from_gb(713) | 61 | 7168 | True |
| kimi-k2-instruct-4bit | mlx-community/Kimi-K2-Instruct-4bit | Memory.from_gb(578) | 61 | 7168 | True |
| kimi-k2-thinking | mlx-community/Kimi-K2-Thinking | Memory.from_gb(658) | 61 | 7168 | True |
| llama-3.1-8b | mlx-community/Meta-Llama-3.1-8B-Instruct-4bit | Memory.from_mb(4423) | 32 | 4096 | True |
| llama-3.1-8b-8bit | mlx-community/Meta-Llama-3.1-8B-Instruct-8bit | Memory.from_mb(8540) | 32 | 4096 | True |
| llama-3.1-8b-bf16 | mlx-community/Meta-Llama-3.1-8B-Instruct-bf16 | Memory.from_mb(16100) | 32 | 4096 | True |
| llama-3.1-70b | mlx-community/Meta-Llama-3.1-70B-Instruct-4bit | Memory.from_mb(38769) | 80 | 8192 | True |
| llama-3.2-1b | mlx-community/Llama-3.2-1B-Instruct-4bit | Memory.from_mb(696) | 16 | 2048 | True |
| llama-3.2-3b | mlx-community/Llama-3.2-3B-Instruct-4bit | Memory.from_mb(1777) | 28 | 3072 | True |
| llama-3.2-3b-8bit | mlx-community/Llama-3.2-3B-Instruct-8bit | Memory.from_mb(3339) | 28 | 3072 | True |
| llama-3.3-70b | mlx-community/Llama-3.3-70B-Instruct-4bit | Memory.from_mb(38769) | 80 | 8192 | True |
| llama-3.3-70b-8bit | mlx-community/Llama-3.3-70B-Instruct-8bit | Memory.from_mb(73242) | 80 | 8192 | True |
| llama-3.3-70b-fp16 | mlx-community/llama-3.3-70b-instruct-fp16 | Memory.from_mb(137695) | 80 | 8192 | True |
| qwen3-0.6b | mlx-community/Qwen3-0.6B-4bit | Memory.from_mb(327) | 28 | 1024 | False |
| qwen3-0.6b-8bit | mlx-community/Qwen3-0.6B-8bit | Memory.from_mb(666) | 28 | 1024 | False |
| qwen3-30b | mlx-community/Qwen3-30B-A3B-4bit | Memory.from_mb(16797) | 48 | 2048 | True |
| qwen3-30b-8bit | mlx-community/Qwen3-30B-A3B-8bit | Memory.from_mb(31738) | 48 | 2048 | True |
| qwen3-80b-a3B-4bit | mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit | Memory.from_mb(44800) | 48 | 2048 | True |
| qwen3-80b-a3B-8bit | mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit | Memory.from_mb(84700) | 48 | 2048 | True |
| qwen3-80b-a3B-thinking-4bit | mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit | Memory.from_mb(84700) | 48 | 2048 | True |
| qwen3-80b-a3B-thinking-8bit | mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit | Memory.from_mb(84700) | 48 | 2048 | True |
| qwen3-235b-a22b-4bit | mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit | Memory.from_gb(132) | 94 | 4096 | True |
| qwen3-235b-a22b-8bit | mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit | Memory.from_gb(250) | 94 | 4096 | True |
| qwen3-coder-480b-a35b-4bit | mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit | Memory.from_gb(270) | 62 | 6144 | True |
| qwen3-coder-480b-a35b-8bit | mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit | Memory.from_gb(540) | 62 | 6144 | True |
| gpt-oss-120b-MXFP4-Q8 | mlx-community/gpt-oss-120b-MXFP4-Q8 | Memory.from_kb(68_996_301) | 36 | 2880 | True |
| gpt-oss-20b-MXFP4-Q8 | mlx-community/gpt-oss-20b-MXFP4-Q8 | Memory.from_kb(11_744_051) | 24 | 2880 | True |
| glm-4.5-air-8bit | mlx-community/GLM-4.5-Air-8bit | Memory.from_gb(114) | 46 | 4096 | False |
| glm-4.5-air-bf16 | mlx-community/GLM-4.5-Air-bf16 | Memory.from_gb(214) | 46 | 4096 | True |
| glm-4.7-4bit | mlx-community/GLM-4.7-4bit | Memory.from_bytes(198556925568) | 91 | 5120 | True |
| glm-4.7-6bit | mlx-community/GLM-4.7-6bit | Memory.from_bytes(286737579648) | 91 | 5120 | True |
| glm-4.7-8bit-gs32 | mlx-community/GLM-4.7-8bit-gs32 | Memory.from_bytes(396963397248) | 91 | 5120 | True |

The dict also carried grouping comments (# deepseek v3, # kimi k2, # llama-3.1, # llama-3.2, # llama-3.3, # qwen3, # gpt-oss, # glm 4.5, # glm 4.7, # glm 4.7 flash, # minimax-m2) and a note on glm-4.5-air-8bit that it needs to be quantized g32 or g16 to work with tensor parallel. The glm-4.7-flash-4bit entry is cut off in this hunk at `model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),`; original lines 308–330 are not shown.

@@ -331,28 +127,10 @@ MODEL_CARDS: dict[str, ModelCard] = {

The second hunk resumes at the tail of an entry (`hidden_size=2048, supports_tensor=True),`) and removes the final two entries before the closing brace:

| key | model_id | storage_size | n_layers | hidden_size | supports_tensor |
|---|---|---|---|---|---|
| minimax-m2.1-8bit | mlx-community/MiniMax-M2.1-8bit | Memory.from_bytes(242986745856) | 61 | 3072 | True |
| minimax-m2.1-3bit | mlx-community/MiniMax-M2.1-3bit | Memory.from_bytes(100086644736) | 61 | 3072 | True |

}

Below the closing brace the diff shows the deferred import block:

from exo.worker.download.download_utils import (  # noqa: E402
    ModelSafetensorsIndex,
    download_file_with_retry,
    ensure_models_dir,
)
class ConfigData(BaseModel):
@@ -366,20 +144,6 @@ class ConfigData(BaseModel):
    num_decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Transformer models
    decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Some architectures
    hidden_size: Annotated[int, Field(ge=0)] | None = None
    architectures: list[str] | None = None

    @property
    def supports_tensor(self) -> bool:
        return self.architectures in [
            ["Glm4MoeLiteForCausalLM"],
            ["DeepseekV32ForCausalLM"],
            ["DeepseekV3ForCausalLM"],
            ["Qwen3NextForCausalLM"],
            ["Qwen3MoeForCausalLM"],
            ["MiniMaxM2ForCausalLM"],
            ["LlamaForCausalLM"],
            ["GptOssForCausalLM"],
        ]

    @property
    def layer_count(self) -> int:
@@ -407,7 +171,7 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
    target_dir = (await ensure_models_dir()) / model_id.normalize()
    await aios.makedirs(target_dir, exist_ok=True)
    config_path = await download_file_with_retry(
        model_id,
        str(model_id),
        "main",
        "config.json",
        target_dir,
@@ -424,7 +188,7 @@ async def get_safetensors_size(model_id: ModelId) -> Memory:
    target_dir = (await ensure_models_dir()) / model_id.normalize()
    await aios.makedirs(target_dir, exist_ok=True)
    index_path = await download_file_with_retry(
        model_id,
        str(model_id),
        "main",
        "model.safetensors.index.json",
        target_dir,
src/exo/shared/models/model_meta.py (new file, 112 lines)

@@ -0,0 +1,112 @@
from typing import Annotated

import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,
    ensure_models_dir,
)


class ConfigData(BaseModel):
    model_config = {"extra": "ignore"}  # Allow unknown fields

    # Common field names for number of layers across different architectures
    num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
    num_layers: Annotated[int, Field(ge=0)] | None = None
    n_layer: Annotated[int, Field(ge=0)] | None = None
    n_layers: Annotated[int, Field(ge=0)] | None = None  # Sometimes used
    num_decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Transformer models
    decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Some architectures
    hidden_size: Annotated[int, Field(ge=0)] | None = None

    @property
    def layer_count(self) -> int:
        # Check common field names for layer count
        layer_fields = [
            self.num_hidden_layers,
            self.num_layers,
            self.n_layer,
            self.n_layers,
            self.num_decoder_layers,
            self.decoder_layers,
        ]

        for layer_count in layer_fields:
            if layer_count is not None:
                return layer_count

        raise ValueError(
            f"No layer count found in config.json: {self.model_dump_json()}"
        )


async def get_config_data(model_id: str) -> ConfigData:
    """Downloads and parses config.json for a model."""
    target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    config_path = await download_file_with_retry(
        model_id,
        "main",
        "config.json",
        target_dir,
        lambda curr_bytes, total_bytes, is_renamed: logger.info(
            f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
        ),
    )
    async with aiofiles.open(config_path, "r") as f:
        return ConfigData.model_validate_json(await f.read())


async def get_safetensors_size(model_id: str) -> Memory:
    """Gets model size from safetensors index or falls back to HF API."""
    target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    index_path = await download_file_with_retry(
        model_id,
        "main",
        "model.safetensors.index.json",
        target_dir,
        lambda curr_bytes, total_bytes, is_renamed: logger.info(
            f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
        ),
    )
    async with aiofiles.open(index_path, "r") as f:
        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

    metadata = index_data.metadata
    if metadata is not None:
        return Memory.from_bytes(metadata.total_size)

    info = model_info(model_id)
    if info.safetensors is None:
        raise ValueError(f"No safetensors info found for {model_id}")
    return Memory.from_bytes(info.safetensors.total)


_model_card_cache: dict[str, ModelCard] = {}


async def get_model_card(model_id: str) -> ModelCard:
    """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
    if model_id in _model_card_cache:
        return _model_card_cache[model_id]
    config_data = await get_config_data(model_id)
    num_layers = config_data.layer_count
    mem_size_bytes = await get_safetensors_size(model_id)

    mc = ModelCard(
        model_id=ModelId(model_id),
        storage_size=mem_size_bytes,
        n_layers=num_layers,
        hidden_size=config_data.hidden_size or 0,
        # TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
        supports_tensor=False,
    )
    _model_card_cache[model_id] = mc
    return mc
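A hypothetical sketch of how the new `model_meta.get_model_card()` helper could be used for a model with no bundled card; it reads `config.json` and the safetensors index from the Hub, so it needs network access, and the repo id below is only an example:

```python
import anyio
from exo.shared.models.model_meta import get_model_card

async def main() -> None:
    card = await get_model_card("mlx-community/Llama-3.2-1B-Instruct-4bit")
    # supports_tensor is always False here (see the TODO in the constructor above).
    print(card.n_layers, card.hidden_size, card.storage_size, card.supports_tensor)

anyio.run(main)
```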
@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):


class PlaceInstanceParams(BaseModel):
    model_id: ModelId
    model_id: str
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
@@ -5,7 +5,6 @@ from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
from .worker.runner_response import ToolCallItem


class ChunkType(str, Enum):
@@ -26,12 +25,8 @@ class TokenChunk(BaseChunk):
    error_message: str | None = None


class ToolCallChunk(BaseChunk):
    tool_calls: list[ToolCallItem]


class ImageChunk(BaseChunk):
    data: bytes


GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk
GenerationChunk = TokenChunk | ImageChunk
@@ -25,14 +25,6 @@ class NodeId(Id):
    pass


class ModelId(Id):
    def normalize(self) -> str:
        return self.replace("/", "--")

    def short(self) -> str:
        return self.split("/")[-1]


class SessionId(CamelCaseModel):
    master_node_id: NodeId
    election_clock: int
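The `ModelId` helpers shown here also appear in the model_cards.py hunk above, so only their home module changes. For reference, a small sketch of what the two methods produce (`normalize()` feeds the `resources/` file names and model cache directories, `short()` the user-facing id):

```python
from exo.shared.models.model_cards import ModelId

m = ModelId("mlx-community/Qwen3-0.6B-4bit")
print(m.normalize())  # mlx-community--Qwen3-0.6B-4bit
print(m.short())      # Qwen3-0.6B-4bit
```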
@@ -1,8 +1,3 @@
from datetime import timedelta
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, PositiveInt

from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -47,50 +42,3 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
    DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)


class ModelSafetensorsIndexMetadata(BaseModel):
    total_size: PositiveInt


class ModelSafetensorsIndex(BaseModel):
    metadata: ModelSafetensorsIndexMetadata | None
    weight_map: dict[str, str]


class FileListEntry(BaseModel):
    type: Literal["file", "directory"]
    path: str
    size: int | None = None


class RepoFileDownloadProgress(BaseModel):
    repo_id: str
    repo_revision: str
    file_path: str
    downloaded: Memory
    downloaded_this_session: Memory
    total: Memory
    speed: float
    eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    start_time: float

    model_config = ConfigDict(frozen=True)


class RepoDownloadProgress(BaseModel):
    repo_id: str
    repo_revision: str
    shard: ShardMetadata
    completed_files: int
    total_files: int
    downloaded_bytes: Memory
    downloaded_bytes_this_session: Memory
    total_bytes: Memory
    overall_speed: float
    overall_eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)

    model_config = ConfigDict(frozen=True)
@@ -1,5 +1,5 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
from exo.utils.pydantic_ext import TaggedModel


class BaseRunnerResponse(TaggedModel):
@@ -10,11 +10,6 @@ class TokenizedResponse(BaseRunnerResponse):
    prompt_tokens: int


class ToolCallItem(CamelCaseModel):
    arguments: str
    name: str


class GenerationResponse(BaseRunnerResponse):
    text: str
    token: int
@@ -23,9 +18,5 @@ class GenerationResponse(BaseRunnerResponse):
    stats: GenerationStats | None = None


class ToolCallResponse(BaseRunnerResponse):
    tool_calls: list[ToolCallItem]


class FinishedResponse(BaseRunnerResponse):
    pass
@@ -1,45 +1,72 @@
import enum
import os
import sys
from pathlib import Path
from typing import cast

from anyio import Path


def find_dashboard() -> Path:
    dashboard = (
        _find_dashboard_in_env()
        or _find_dashboard_in_repo()
        or _find_dashboard_in_bundle()

class RuntimeResources(enum.Enum):
    Dashboard = enum.auto
    Resources = enum.auto

_dir_cache: dict[RuntimeResources, Path]

async def find_directory(rr: RuntimeResources) -> Path:
    dir = (
        _dir_cache.get(rr, None)
        or await _find_in_env(rr)
        or await _find_in_repo(rr)
        or await _find_in_bundle(rr)
    )
    if not dashboard:
    if not dir:
        raise FileNotFoundError(
            "Unable to locate dashboard assets - make sure the dashboard has been built, or export DASHBOARD_DIR if you've built the dashboard elsewhere."
            "Unable to locate directory - make sure the dashboard has been built and the runtime resources (model cards) exist."
        )
    return dashboard
    _dir_cache[rr] = dir
    return dir


def _find_dashboard_in_env() -> Path | None:
    env = os.environ.get("DASHBOARD_DIR")
async def _find_in_env(rr: RuntimeResources) -> Path | None:
    match rr:
        case RuntimeResources.Dashboard:
            env = os.environ.get("DASHBOARD_DIR")
        case RuntimeResources.Resources:
            env = os.environ.get("RESOURCES_DIR")
    if not env:
        return None
    resolved_env = Path(env).expanduser().resolve()
    resolved_env = await (await Path(env).expanduser()).resolve()

    return resolved_env


def _find_dashboard_in_repo() -> Path | None:
    current_module = Path(__file__).resolve()
async def _find_in_repo(rr: RuntimeResources) -> Path | None:
    current_module = await Path(__file__).resolve()
    for parent in current_module.parents:
        build = parent / "dashboard" / "build"
        if build.is_dir() and (build / "index.html").exists():
            return build
        match rr:
            case RuntimeResources.Dashboard:
                build = parent / "dashboard" / "build"
                if await build.is_dir() and await (build / "index.html").exists():
                    return build
            case RuntimeResources.Resources:
                res = parent / "resources"
                if await res.is_dir():
                    return res
    return None


def _find_dashboard_in_bundle() -> Path | None:
async def _find_in_bundle(rr: RuntimeResources) -> Path | None:
    frozen_root = cast(str | None, getattr(sys, "_MEIPASS", None))
    if frozen_root is None:
        return None
    candidate = Path(frozen_root) / "dashboard"
    if candidate.is_dir():
        return candidate

    match rr:
        case RuntimeResources.Dashboard:
            candidate = Path(frozen_root) / "dashboard"
            if await candidate.is_dir():
                return candidate
        case RuntimeResources.Resources:
            candidate = Path(frozen_root) / "resources"
            if await candidate.is_dir():
                return candidate
    return None
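A hypothetical usage sketch of the generalised resolver: resolution order is the environment variable (`DASHBOARD_DIR` or `RESOURCES_DIR`), then the repository checkout, then a PyInstaller bundle via `sys._MEIPASS`, with the result cached per `RuntimeResources` value:

```python
import anyio
from exo.utils.dashboard_path import RuntimeResources, find_directory

async def main() -> None:
    dashboard = await find_directory(RuntimeResources.Dashboard)  # honours DASHBOARD_DIR
    resources = await find_directory(RuntimeResources.Resources)  # honours RESOURCES_DIR
    print(dashboard, resources)

anyio.run(main)
```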
@@ -17,20 +17,17 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
    BaseModel,
    ConfigDict,
    DirectoryPath,
    Field,
    PositiveInt,
    TypeAdapter,
)

from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
    DownloadProgressData,
    FileListEntry,
    ModelSafetensorsIndex,
    RepoDownloadProgress,
    RepoFileDownloadProgress,
)
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
    filter_repo_objects,
@@ -40,6 +37,53 @@ from exo.worker.download.huggingface_utils import (
)


class ModelSafetensorsIndexMetadata(BaseModel):
    total_size: PositiveInt


class ModelSafetensorsIndex(BaseModel):
    metadata: ModelSafetensorsIndexMetadata | None
    weight_map: dict[str, str]


class FileListEntry(BaseModel):
    type: Literal["file", "directory"]
    path: str
    size: int | None = None


class RepoFileDownloadProgress(BaseModel):
    repo_id: str
    repo_revision: str
    file_path: str
    downloaded: Memory
    downloaded_this_session: Memory
    total: Memory
    speed: float
    eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    start_time: float

    model_config = ConfigDict(frozen=True)


class RepoDownloadProgress(BaseModel):
    repo_id: str
    repo_revision: str
    shard: ShardMetadata
    completed_files: int
    total_files: int
    downloaded_bytes: Memory
    downloaded_bytes_this_session: Memory
    total_bytes: Memory
    overall_speed: float
    overall_eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]
    file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)

    model_config = ConfigDict(frozen=True)


def trim_etag(etag: str) -> str:
    if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
        return etag[1:-1]
@@ -81,12 +125,12 @@ def map_repo_download_progress_to_download_progress_data(
    )


def build_model_path(model_id: ModelId) -> DirectoryPath:
    return EXO_MODELS_DIR / model_id.normalize()
def build_model_path(model_id: str) -> DirectoryPath:
    return EXO_MODELS_DIR / model_id.replace("/", "--")


async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
    return (await ensure_models_dir()) / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
    return (await ensure_models_dir()) / repo_id.replace("/", "--")


async def ensure_models_dir() -> Path:
@@ -94,8 +138,8 @@ async def ensure_models_dir() -> Path:
    return EXO_MODELS_DIR


async def delete_model(model_id: ModelId) -> bool:
    model_dir = await ensure_models_dir() / model_id.normalize()
async def delete_model(repo_id: str) -> bool:
    model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
    if not await aios.path.exists(model_dir):
        return False
    await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -120,17 +164,19 @@ async def seed_models(seed_dir: str | Path):


async def fetch_file_list_with_cache(
    model_id: ModelId, revision: str = "main", recursive: bool = False
    repo_id: str, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
    target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
    target_dir = (
        (await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
    )
    await aios.makedirs(target_dir, exist_ok=True)
    cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
    cache_file = (
        target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
    )
    if await aios.path.exists(cache_file):
        async with aiofiles.open(cache_file, "r") as f:
            return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
    file_list = await fetch_file_list_with_retry(
        model_id, revision, recursive=recursive
    )
    file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
    await aios.makedirs(cache_file.parent, exist_ok=True)
    async with aiofiles.open(cache_file, "w") as f:
        await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -138,25 +184,25 @@ async def fetch_file_list_with_cache(


async def fetch_file_list_with_retry(
    model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
    repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
    n_attempts = 30
    for attempt in range(n_attempts):
        try:
            return await _fetch_file_list(model_id, revision, path, recursive)
            return await _fetch_file_list(repo_id, revision, path, recursive)
        except Exception as e:
            if attempt == n_attempts - 1:
                raise e
            await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
    raise Exception(
        f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
        f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
    )


async def _fetch_file_list(
    model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
    repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
    api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
    url = f"{api_url}/{path}" if path else api_url

    headers = await get_download_headers()
@@ -173,7 +219,7 @@ async def _fetch_file_list(
            files.append(FileListEntry.model_validate(item))
        elif item.type == "directory" and recursive:
            subfiles = await _fetch_file_list(
                model_id, revision, item.path, recursive
                repo_id, revision, item.path, recursive
            )
            files.extend(subfiles)
    return files
@@ -230,10 +276,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -


async def file_meta(
    model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
    repo_id: str, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
    url = (
        urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
        urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
        if redirected_location is None
        else f"{get_hf_endpoint()}{redirected_location}"
    )
@@ -252,7 +298,7 @@ async def file_meta(
            return content_length, etag
        # Otherwise, follow the redirect to get authoritative size/hash
        redirected_location = r.headers.get("location")
        return await file_meta(model_id, revision, path, redirected_location)
        return await file_meta(repo_id, revision, path, redirected_location)
    content_length = int(
        r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
    )
@@ -264,7 +310,7 @@ async def file_meta(


async def download_file_with_retry(
    model_id: ModelId,
    repo_id: str,
    revision: str,
    path: str,
    target_dir: Path,
@@ -274,23 +320,23 @@ async def download_file_with_retry(
    for attempt in range(n_attempts):
        try:
            return await _download_file(
                model_id, revision, path, target_dir, on_progress
                repo_id, revision, path, target_dir, on_progress
            )
        except Exception as e:
            if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
                raise e
            logger.error(
                f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
                f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
            )
            logger.error(traceback.format_exc())
            await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
    raise Exception(
        f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
        f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
    )


async def _download_file(
    model_id: ModelId,
    repo_id: str,
    revision: str,
    path: str,
    target_dir: Path,
@@ -299,7 +345,7 @@ async def _download_file(
    if await aios.path.exists(target_dir / path):
        return target_dir / path
    await aios.makedirs((target_dir / path).parent, exist_ok=True)
    length, etag = await file_meta(model_id, revision, path)
    length, etag = await file_meta(repo_id, revision, path)
    remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
    partial_path = target_dir / f"{path}.partial"
    resume_byte_pos = (
@@ -308,7 +354,7 @@ async def _download_file(
        else None
    )
    if resume_byte_pos != length:
        url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
        url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
        headers = await get_download_headers()
        if resume_byte_pos:
            headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -348,7 +394,7 @@ async def _download_file(

def calculate_repo_progress(
    shard: ShardMetadata,
    model_id: ModelId,
    repo_id: str,
    revision: str,
    file_progress: dict[str, RepoFileDownloadProgress],
    all_start_time: float,
@@ -377,7 +423,7 @@ def calculate_repo_progress(
        else "not_started"
    )
    return RepoDownloadProgress(
        repo_id=model_id,
        repo_id=repo_id,
        repo_revision=revision,
        shard=shard,
        completed_files=len(
@@ -396,11 +442,11 @@ def calculate_repo_progress(
    )


async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
    target_dir = (await ensure_models_dir()) / model_id.normalize()
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
    target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    index_file = await download_file_with_retry(
        model_id, revision, "model.safetensors.index.json", target_dir
        repo_id, revision, "model.safetensors.index.json", target_dir
    )
    async with aiofiles.open(index_file, "r") as f:
        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
||||
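get_weight_map reads the model.safetensors.index.json that sharded Hugging Face checkpoints ship; its weight_map section maps each tensor name to the shard file containing it. A rough sketch of the shape this code expects (the tensor names, file names, and sizes below are illustrative, not taken from any real repository):

import json

index_data = json.loads("""
{
  "metadata": {"total_size": 123456789},
  "weight_map": {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors"
  }
}
""")
weight_map: dict[str, str] = index_data["weight_map"]
print(sorted(set(weight_map.values())))  # the shard files a node actually needs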
@@ -431,6 +477,53 @@ async def get_downloaded_size(path: Path) -> int:
    return 0


async def download_progress_for_local_path(
    repo_id: str, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
    file_progress: dict[str, RepoFileDownloadProgress] = {}
    total_files = 0
    total_bytes = 0

    if await aios.path.isdir(local_path):
        for root, _, files in os.walk(local_path):
            for f in files:
                if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
                    file_path = Path(root) / f
                    size = (await aios.stat(file_path)).st_size
                    rel_path = str(file_path.relative_to(local_path))
                    file_progress[rel_path] = RepoFileDownloadProgress(
                        repo_id=repo_id,
                        repo_revision="local",
                        file_path=rel_path,
                        downloaded=Memory.from_bytes(size),
                        downloaded_this_session=Memory.from_bytes(0),
                        total=Memory.from_bytes(size),
                        speed=0,
                        eta=timedelta(0),
                        status="complete",
                        start_time=time.time(),
                    )
                    total_files += 1
                    total_bytes += size
    else:
        raise ValueError(f"Local path {local_path} is not a directory")

    return RepoDownloadProgress(
        repo_id=repo_id,
        repo_revision="local",
        shard=shard,
        completed_files=total_files,
        total_files=total_files,
        downloaded_bytes=Memory.from_bytes(total_bytes),
        downloaded_bytes_this_session=Memory.from_bytes(0),
        total_bytes=Memory.from_bytes(total_bytes),
        overall_speed=0,
        overall_eta=timedelta(0),
        status="complete",
        file_progress=file_progress,
    )

async def download_shard(
    shard: ShardMetadata,
    on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -441,6 +534,14 @@ async def download_shard(
    if not skip_download:
        logger.info(f"Downloading {shard.model_card.model_id=}")

    # Handle local paths
    if await aios.path.exists(str(shard.model_card.model_id)):
        logger.info(f"Using local model path {shard.model_card.model_id}")
        local_path = Path(str(shard.model_card.model_id))
        return local_path, await download_progress_for_local_path(
            str(shard.model_card.model_id), shard, local_path
        )

    revision = "main"
    target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
        "/", "--"
@@ -451,14 +552,13 @@ async def download_shard(
    if not allow_patterns:
        allow_patterns = await resolve_allow_patterns(shard)

    if not skip_download:
        logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
    logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")

    all_start_time = time.time()
    # TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
    # Update: <- This does not seem to be the case. Yay?
    file_list = await fetch_file_list_with_cache(
        shard.model_card.model_id, revision, recursive=True
        str(shard.model_card.model_id), revision, recursive=True
    )
    filtered_file_list = list(
        filter_repo_objects(
@@ -492,7 +592,7 @@ async def download_shard(
            else timedelta(seconds=0)
        )
        file_progress[file.path] = RepoFileDownloadProgress(
            repo_id=shard.model_card.model_id,
            repo_id=str(shard.model_card.model_id),
            repo_revision=revision,
            file_path=file.path,
            downloaded=Memory.from_bytes(curr_bytes),
@@ -509,7 +609,7 @@ async def download_shard(
            shard,
            calculate_repo_progress(
                shard,
                shard.model_card.model_id,
                str(shard.model_card.model_id),
                revision,
                file_progress,
                all_start_time,
@@ -519,7 +619,7 @@ async def download_shard(
    for file in filtered_file_list:
        downloaded_bytes = await get_downloaded_size(target_dir / file.path)
        file_progress[file.path] = RepoFileDownloadProgress(
            repo_id=shard.model_card.model_id,
            repo_id=str(shard.model_card.model_id),
            repo_revision=revision,
            file_path=file.path,
            downloaded=Memory.from_bytes(downloaded_bytes),
@@ -543,7 +643,7 @@ async def download_shard(
    async def download_with_semaphore(file: FileListEntry) -> None:
        async with semaphore:
            await download_file_with_retry(
                shard.model_card.model_id,
                str(shard.model_card.model_id),
                revision,
                file.path,
                target_dir,
@@ -557,7 +657,7 @@ async def download_shard(
        *[download_with_semaphore(file) for file in filtered_file_list]
    )
    final_repo_progress = calculate_repo_progress(
        shard, shard.model_card.model_id, revision, file_progress, all_start_time
        shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
    )
    await on_progress(shard, final_repo_progress)
    if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

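The per-file downloads above are fanned out with asyncio.gather but throttled by a semaphore, so only a bounded number of transfers run at once. A reduced sketch of that pattern, with placeholder items and download_one arguments standing in for the file list and download_file_with_retry:

import asyncio

async def download_all(items, download_one, max_parallel_downloads: int = 8):
    semaphore = asyncio.Semaphore(max_parallel_downloads)

    async def bounded(item):
        # The semaphore caps how many download_one calls run concurrently.
        async with semaphore:
            return await download_one(item)

    return await asyncio.gather(*[bounded(item) for item in items])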
@@ -3,7 +3,7 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, get_model_cards
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,
@@ -18,7 +18,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
    )


async def build_base_shard(model_id: ModelId) -> ShardMetadata:
async def build_base_shard(model_id: str) -> ShardMetadata:
    model_card = await ModelCard.from_hf(model_id)
    return PipelineShardMetadata(
        model_card=model_card,
@@ -30,7 +30,7 @@ async def build_base_shard(model_id: ModelId) -> ShardMetadata:
    )


async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
    base_shard = await build_base_shard(model_id)
    return PipelineShardMetadata(
        model_card=base_shard.model_card,
@@ -147,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
        self,
    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
        async def _status_for_model(
            model_id: ModelId,
            model_id: str,
        ) -> tuple[Path, RepoDownloadProgress]:
            """Helper coroutine that builds the shard for a model and gets its download status."""
            shard = await build_full_shard(model_id)
@@ -158,7 +158,7 @@ class ResumableShardDownloader(ShardDownloader):
        # Kick off download status coroutines concurrently
        tasks = [
            asyncio.create_task(_status_for_model(model_card.model_id))
            for model_card in MODEL_CARDS.values()
            for model_card in await get_model_cards()
        ]

        for task in asyncio.as_completed(tasks):

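Wrapping every model's status check in asyncio.create_task and then iterating asyncio.as_completed lets the downloader surface results for whichever model finishes first instead of waiting in list order. A small self-contained sketch of that shape, where check_status is a stand-in for the real per-model coroutine:

import asyncio

async def stream_statuses(model_ids, check_status):
    # Start every status check eagerly, then yield results as they complete.
    tasks = [asyncio.create_task(check_status(m)) for m in model_ids]
    for task in asyncio.as_completed(tasks):
        yield await task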
@@ -83,11 +83,11 @@ class CustomMlxLayer(nn.Module):

    def __init__(self, original_layer: _LayerCallable):
        super().__init__()
        dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
        object.__setattr__(self, "_original_layer", original_layer)

    @property
    def original_layer(self) -> _LayerCallable:
        return cast(_LayerCallable, self["_original_layer"])
        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))

    # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
    if not TYPE_CHECKING:
@@ -96,7 +96,7 @@ class CustomMlxLayer(nn.Module):
            try:
                return super().__getattr__(name)
            except AttributeError:
                original_layer = cast(_LayerCallable, self["_original_layer"])
                original_layer = object.__getattribute__(self, "_original_layer")
                return getattr(original_layer, name)


@@ -334,7 +334,7 @@ def tensor_auto_parallel(
        group=group,
    )

    if hasattr(model, "shard") and not isinstance(model, GptOssModel):
    if hasattr(model, "shard"):
        try:
            model.shard(group) # type: ignore
            return patch_tensor_model(model)
@@ -383,6 +383,7 @@ def tensor_auto_parallel(
        all_to_sharded_linear_in_place,
        sharded_to_all_linear_in_place,
    )

    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

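Switching from dict.__setitem__ to object.__setattr__ stores the wrapped layer as a plain Python attribute rather than an entry in the module's backing dict, which keeps it out of dict-based traversal while leaving it reachable through object.__getattribute__. A toy illustration of that difference, using a simple dict subclass as a stand-in rather than the real mlx.nn.Module:

class DictBackedModule(dict):
    """Stand-in for a dict-backed module container (not the real mlx.nn.Module)."""

class Wrapper(DictBackedModule):
    def __init__(self, inner):
        super().__init__()
        # Stored as a plain attribute: invisible to dict iteration / child lookup.
        object.__setattr__(self, "_inner", inner)

    @property
    def inner(self):
        return object.__getattribute__(self, "_inner")

w = Wrapper(inner=object())
print(list(w.items()))       # [] -> the wrapped object is not tracked by the dict
print(w.inner is not None)   # True -> but it is still reachable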
@@ -23,7 +23,6 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
    CACHE_GROUP_SIZE,
    KV_CACHE_BITS,
@@ -297,7 +296,7 @@ def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerW
    return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)


def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
    """
    Get the EOS token IDs for a model based on its ID.

@@ -321,9 +320,7 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
    return None


def load_tokenizer_for_model_id(
    model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
    """
    Load tokenizer for a model given its ID and local path.

@@ -413,6 +413,11 @@ class Worker:
                )
            for nid in conns:
                for ip in conns[nid]:
                    if "127.0.0.1" in ip or "localhost" in ip:
                        logger.warning(
                            f"Loopback connection should not happen: {ip=} for {nid=}"
                        )

                    edge = SocketConnection(
                        # nonsense multiaddr
                        sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -433,9 +438,6 @@ class Worker:
            for conn in self.state.topology.out_edges(self.node_id):
                if not isinstance(conn.edge, SocketConnection):
                    continue
                # ignore mDNS discovered connections
                if conn.edge.sink_multiaddr.port != 52415:
                    continue
                if (
                    conn.sink not in conns
                    or conn.edge.sink_multiaddr.ip_address
@@ -449,7 +451,7 @@ class Worker:
    async def _emit_existing_download_progress(self) -> None:
        try:
            while True:
                logger.debug("Fetching and emitting existing download progress...")
                logger.info("Fetching and emitting existing download progress...")
                async for (
                    _,
                    progress,
@@ -480,7 +482,7 @@ class Worker:
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=status)
                )
                logger.debug("Done emitting existing download progress.")
                logger.info("Done emitting existing download progress.")
                await anyio.sleep(5 * 60) # 5 minutes
        except Exception as e:
            logger.error(f"Error emitting existing download progress: {e}")

@@ -1,8 +1,6 @@
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, cast

import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -13,10 +11,9 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
    StreamableParser,
    load_harmony_encoding,
)
from pydantic import ValidationError

from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
@@ -36,8 +33,6 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
    GenerationResponse,
    ToolCallItem,
    ToolCallResponse,
)
from exo.shared.types.worker.runners import (
    RunnerConnected,
@@ -207,13 +202,7 @@ def main(
            mlx_generator, tokenizer
        )

        # TODO: Add call parser here
        if (
            tokenizer.tool_parser # pyright: ignore[reportAny]
            or tokenizer.tool_call_start
            or tokenizer.tool_call_end
        ) is not None:
            mlx_generator = parse_tool_calls(mlx_generator, tokenizer)
        # TODO: Add tool call parser here

        for response in mlx_generator:
            match response:
@@ -223,7 +212,7 @@ def main(
                    ChunkGenerated(
                        command_id=command_id,
                        chunk=TokenChunk(
                            idx=response.token, # hang on --- is this wrong?? this is totally wrong
                            idx=response.token,
                            model=shard_metadata.model_card.model_id,
                            text=response.text,
                            token_id=response.token,
@@ -232,18 +221,6 @@ def main(
                        ),
                    )
                )
                case ToolCallResponse():
                    if device_rank == 0:
                        event_sender.send(
                            ChunkGenerated(
                                command_id=command_id,
                                chunk=ToolCallChunk(
                                    idx=4,
                                    tool_calls=response.tool_calls,
                                    model=shard_metadata.model_card.model_id,
                                ),
                            )
                        )

    # can we make this more explicit?
    except Exception as e:
@@ -352,50 +329,6 @@ def parse_thinking_models(
        yield response


def parse_tool_calls(
    responses: Generator[GenerationResponse],
    tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
    assert tokenizer.tool_call_start is not None
    assert tokenizer.tool_call_end is not None
    tool_parser = cast(
        Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None,
        tokenizer.tool_parser,
    ) # first arg has followup args, but we don't care about them rn
    assert tool_parser is not None
    in_tool_call = False
    tool_call_text_parts: list[str] = []
    for response in responses:
        # assumption: the tool call start is one token
        if response.text == tokenizer.tool_call_start:
            in_tool_call = True
            continue
        if in_tool_call:
            tool_call_text_parts.append(response.text)
            continue
        # assumption: the tool call end is one token
        if response.text == tokenizer.tool_call_end:
            try:
                # tool_parser returns an arbitrarily nested python dictionary
                # we actually don't want the python dictionary, we just want to
                # parse the top level { function: ..., arguments: ... } structure
                # as we're just gonna hand it back to the api anyway
                parsed = tool_parser("".join(tool_call_text_parts).strip())
                if isinstance(parsed, list):
                    tools = [
                        ToolCallItem.model_validate_json(json.dumps(tool))
                        for tool in parsed
                    ]
                else:
                    tools = [ToolCallItem.model_validate_json(json.dumps(parsed))]
                yield ToolCallResponse(tool_calls=tools)
            except (json.JSONDecodeError, ValidationError):
                in_tool_call = False
                tool_call_text_parts = []
        # fallthrough
        yield response


EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

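The removed parse_tool_calls accumulated streamed text between the tokenizer's tool_call_start and tool_call_end sentinels and then handed the joined string to the tokenizer's tool parser. A stripped-down sketch of that accumulation loop, with generic start, end, and parse arguments standing in for the TokenizerWrapper internals:

def collect_tool_calls(chunks, start: str, end: str, parse):
    # Buffer text between the start and end sentinels; everything else passes through.
    buffer: list[str] = []
    in_call = False
    for text in chunks:
        if text == start:
            in_call = True
            continue
        if in_call and text == end:
            in_call = False
            yield ("tool_call", parse("".join(buffer).strip()))
            buffer = []
            continue
        if in_call:
            buffer.append(text)
            continue
        yield ("text", text)

# Example: list(collect_tool_calls(["<tool>", '{"name": "f"}', "</tool>"],
#                                  "<tool>", "</tool>", __import__("json").loads))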
@@ -11,9 +11,8 @@ import mlx.core as mx
import mlx.nn as nn

from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata

@@ -18,7 +18,6 @@ def _check_model_exists() -> bool:


pytestmark = [
    pytest.mark.slow,
    pytest.mark.skipif(
        not _check_model_exists(),
        reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",

@@ -11,7 +11,7 @@ from pathlib import Path

import pytest

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
    download_file_with_retry,
    ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
    return False


async def download_tokenizer_files(model_id: ModelId) -> Path:
async def download_tokenizer_files(model_id: str) -> Path:
    """Download only the tokenizer-related files for a model."""
    target_dir = await ensure_models_dir() / model_id.normalize()
    target_dir = await ensure_models_dir() / model_id.replace("/", "--")
    target_dir.mkdir(parents=True, exist_ok=True)

    file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,24 +72,22 @@ async def download_tokenizer_files(model_id: ModelId) -> Path:


# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[ModelCard]:
def get_test_models() -> list[tuple[str, ModelCard]]:
    """Get a representative sample of models to test."""
    # Pick one model from each family to test
    families: dict[str, ModelCard] = {}
    for card in MODEL_CARDS.values():
    families: dict[str, tuple[str, ModelCard]] = {}
    for _, card in MODEL_CARDS.items():
        # Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
        parts = card.model_id.short().split("-")
        family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]

        if family not in families:
            families[family] = card
            families[family] = (card.model_id.short(), card)

    return list(families.values())


TEST_MODELS: list[ModelCard] = get_test_models()

pytestmark = pytest.mark.slow
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()


@pytest.fixture(scope="module")
@@ -101,13 +99,14 @@ def event_loop():


@pytest.mark.parametrize(
    "model_card",
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
    """Test that tokenizer can encode and decode text correctly."""
    model_id = model_card.model_id
    model_id = str(model_card.model_id)

    # Download tokenizer files
    model_path = await download_tokenizer_files(model_id)
@@ -166,15 +165,16 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->


@pytest.mark.parametrize(
    "model_card",
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
    short_id: str, model_card: ModelCard
) -> None:
    """Test that tokenizer has required attributes for inference."""
    model_id = model_card.model_id
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

@@ -207,18 +207,19 @@ async def test_tokenizer_has_required_attributes(


@pytest.mark.parametrize(
    "model_card",
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
    """Test that tokenizer can encode text containing special tokens.

    This is critical because the actual inference path uses prompts with
    special tokens from chat templates. If special tokens aren't handled
    correctly, encoding will fail.
    """
    model_id = model_card.model_id
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

@@ -298,14 +299,16 @@ async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
async def test_kimi_tokenizer_specifically():
    """Test Kimi tokenizer with its specific patches and quirks."""
    kimi_models = [
        card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
        (short_id, card)
        for short_id, card in MODEL_CARDS.items()
        if "kimi" in short_id.lower()
    ]

    if not kimi_models:
        pytest.skip("No Kimi models found in MODEL_CARDS")

    model_card = kimi_models[0]
    model_id = model_card.model_id
    _, model_card = kimi_models[0]
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

@@ -344,15 +347,17 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
    """Test GLM tokenizer with its specific EOS tokens."""
    glm_model_cards = [
        card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
    glm_models = [
        (short_id, card)
        for short_id, card in MODEL_CARDS.items()
        if "glm" in short_id.lower()
    ]

    if not glm_model_cards:
    if not glm_models:
        pytest.skip("No GLM models found in MODEL_CARDS")

    model_card = glm_model_cards[0]
    model_id = model_card.model_id
    _, model_card = glm_models[0]
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

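The tests now parametrize over (short_id, model_card) pairs and pass ids=[m[0] for m in TEST_MODELS], so each generated test case is labelled with the human-readable short id rather than a repr of the card object. A minimal illustration of the same pytest pattern, with made-up data:

import pytest

CASES = [("alpha", {"size": 1}), ("beta", {"size": 2})]

@pytest.mark.parametrize("short_id,card", CASES, ids=[c[0] for c in CASES])
def test_card_has_size(short_id: str, card: dict) -> None:
    # Test ids will read test_card_has_size[alpha], test_card_has_size[beta].
    assert card["size"] > 0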
@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import ModelId, NodeId
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress

@@ -12,7 +12,7 @@ from loguru import logger
from pydantic import BaseModel

from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
@@ -89,22 +89,22 @@ async def tb_detection():

async def assert_downloads():
    sd = exo_shard_downloader()
    # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
    # await sd.ensure_shard(ModelId("mlx-community/Qwen3-0.6B-8bit")))
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
        await build_full_shard(ModelId("mlx-community/Llama-3.1-8b-bf16"))
    )
    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
    await sd.ensure_shard(await build_full_shard(ModelId("mlx-community/Qwen3-30b-A3B")))
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
        await build_full_shard(ModelId("mlx-commmunity/gpt-oss-120b-MXFP4-Q8"))
    )
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
        await build_full_shard(ModelId("mlx-community/gpt-oss-20b-4bit"))
    )
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
        await build_full_shard(ModelId("mlx-community/GLM-4.7-8bit-gs32"))
    )
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
        await build_full_shard(ModelId("mlx-community/MiniMax-M2.1-8bit"))
    )

uv.lock (generated, 12 changed lines)
@@ -248,7 +248,6 @@ dependencies = [
    { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]

@@ -282,7 +281,6 @@ requires-dist = [
    { name = "pydantic", specifier = ">=2.11.7" },
    { name = "rustworkx", specifier = ">=0.17.1" },
    { name = "tiktoken", specifier = ">=0.12.0" },
    { name = "tomlkit", specifier = ">=0.14.0" },
    { name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]

@@ -317,16 +315,6 @@ dev = [
    { name = "pytest-asyncio", specifier = ">=1.0.0" },
]

[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]


[[package]]
name = "fastapi"
version = "0.128.0"
