update api to match

gotta fix some tests
This commit is contained in:
Evan
2025-12-23 13:48:20 +00:00
parent f3de735fd6
commit 592d389262
14 changed files with 562 additions and 187 deletions

View File

@@ -33,6 +33,7 @@ dependencies = [
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
]
[project.scripts]

View File

@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -21,11 +27,13 @@ from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -56,7 +64,7 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
HIDE_THINKING = False
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
@@ -161,7 +169,9 @@ class API:
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions")(self.chat_completions)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -177,17 +187,32 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
command = CreateInstance(instance=payload.instance)
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
raise HTTPException(
status_code=400,
detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
)
command = CreateInstance(
instance=instance,
)
await self._send(command)
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
)
async def get_placement(
@@ -352,32 +377,52 @@ class API:
instance_id=instance_id,
)
async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
is_thinking = False
with recv as token_chunks:
async for chunk in token_chunks:
if HIDE_THINKING:
if chunk.text == "<think>":
is_thinking = True
if chunk.text == "</think>":
is_thinking = False
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
if not (is_thinking and HIDE_THINKING):
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -392,6 +437,59 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
@@ -399,10 +497,12 @@ class API:
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> StreamingResponse:
"""Handle chat completions with proper streaming response."""
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -417,10 +517,13 @@ class API:
request_params=payload,
)
await self._send(command)
return StreamingResponse(
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -442,6 +545,8 @@ class API:
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
)
for card in MODEL_CARDS.values()
]
@@ -458,7 +563,7 @@ class API:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._applystate)
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
print_startup_banner(self.port)
await serve(
@@ -470,7 +575,7 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()
async def _applystate(self):
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:

View File

@@ -123,6 +123,8 @@ async def test_master():
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
@@ -180,6 +182,8 @@ async def test_master():
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
device_rank=0,
world_size=1,

View File

@@ -49,6 +49,8 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=10,
supports_tensor=True,
)
@@ -135,6 +137,8 @@ def test_get_instance_placements_one_node_exact_fit(
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
@@ -160,6 +164,8 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
@@ -185,6 +191,8 @@ def test_get_instance_placements_one_node_not_fit(
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)

View File

@@ -198,6 +198,8 @@ def test_get_shard_assignments(
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]

View File

@@ -51,6 +51,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
@@ -64,6 +66,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
@@ -135,6 +139,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
@@ -148,6 +154,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
@@ -162,6 +170,38 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
@@ -175,6 +215,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
@@ -189,6 +231,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
@@ -202,6 +246,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
@@ -215,6 +261,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
@@ -229,6 +277,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
@@ -242,6 +292,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
@@ -255,20 +307,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
),
),
# phi-3
"phi-3-mini": ModelCard(
short_id="phi-3-mini",
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
name="Phi 3 Mini 128k (4-bit)",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
pretty_name="Phi 3 Mini 128k (4-bit)",
storage_size=Memory.from_mb(2099),
n_layers=32,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
@@ -283,6 +323,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
@@ -296,6 +338,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
@@ -309,6 +353,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
@@ -322,6 +368,68 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
@@ -335,6 +443,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
@@ -348,6 +458,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
@@ -361,6 +473,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
@@ -374,77 +488,84 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# granite
"granite-3.3-2b": ModelCard(
short_id="granite-3.3-2b",
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
name="Granite 3.3 2B (FP16)",
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
pretty_name="Granite 3.3 2B (FP16)",
storage_size=Memory.from_mb(4951),
n_layers=40,
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
# "granite-3.3-8b": ModelCard(
# short_id="granite-3.3-8b",
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# name="Granite 3.3 8B",
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# Needs to be quantized g32 or g16.
"glm-4.5-air-8bit": ModelCard(
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# pretty_name="Granite 3.3 8B",
# storage_size=Memory.from_kb(15958720),
# n_layers=40,
# ),
# ),
# smol-lm
# "smol-lm-135m": ModelCard(
# short_id="smol-lm-135m",
# model_id="mlx-community/SmolLM-135M-4bit",
# name="Smol LM 135M",
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
# pretty_name="Smol LM 135M",
# storage_size=Memory.from_kb(73940),
# n_layers=30,
# ),
# ),
# gpt-oss
# "gpt-oss-120b-MXFP4-Q8": ModelCard(
# short_id="gpt-oss-120b-MXFP4-Q8",
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# storage_size=Memory.from_kb(68_996_301),
# n_layers=36,
# hidden_size=2880,
# supports_tensor=True,
# ),
# ),
# "gpt-oss-20b-4bit": ModelCard(
# short_id="gpt-oss-20b-4bit",
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# storage_size=Memory.from_kb(11_744_051),
# n_layers=24,
# hidden_size=2880,
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),

View File

@@ -6,6 +6,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
@@ -25,6 +26,7 @@ class ConfigData(BaseModel):
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
@@ -106,10 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_id,
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -174,6 +174,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):

View File

@@ -14,3 +14,5 @@ class ModelMetadata(CamelCaseModel):
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -12,7 +13,7 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Shoudl this be a classmethod?
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
@abstractmethod
async def ensure_shard(
@@ -43,34 +44,7 @@ class ShardDownloader(ABC):
Yields:
tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.
"""
yield (
Path("/tmp/noop_shard"),
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
)
yield (Path("/tmp/noop_shard"), NOOP_DOWNLOAD_PROGRESS)
@abstractmethod
async def get_shard_download_status_for_shard(
@@ -94,46 +68,41 @@ class NoopShardDownloader(ShardDownloader):
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
yield (
Path("/tmp/noop_shard"),
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
NOOP_DOWNLOAD_PROGRESS,
)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
return RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=shard,
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)
dp = copy(NOOP_DOWNLOAD_PROGRESS)
dp.shard = shard
return dp
NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=0,
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)

View File

@@ -2,16 +2,13 @@ import os
import loguru
from exo.shared.types.events import Event
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import MpReceiver, MpSender
logger: "loguru.Logger"
if os.getenv("EXO_TESTS") == "1":
logger = loguru.logger
logger: "loguru.Logger" = loguru.logger
def entrypoint(
@@ -30,6 +27,23 @@ def entrypoint(
logger = _logger
# Import main after setting global logger - this lets us just import logger from this module
from exo.worker.runner.runner import main
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
)
event_sender.send(
RunnerStatusUpdated(
runner_id=bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=str(e)),
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")

View File

@@ -1,5 +1,10 @@
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass
from anyio import ClosedResourceError, WouldBlock
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
@@ -14,6 +19,96 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
# Synchronous trivial sender and receiver.
@dataclass
class _State[T]:
buffer: list[T]
closed: bool = False
class MockSender[T]:
def __init__(self, _state: _State[T] | None = None):
self._state = _state or _State(buffer=[])
self._closed = False
def send(self, item: T):
if self._closed:
raise ClosedResourceError
self._state.buffer.append(item)
def close(self):
self._closed = True
self._state.closed = True
def join(self):
pass
def clone(self) -> MockSender[T]:
if self._closed:
raise ClosedResourceError
return MockSender(_state=self._state)
def clone_receiver(self) -> MockReceiver[T]:
if self._closed:
raise ClosedResourceError
return MockReceiver(_state=self._state)
class MockReceiver[T]:
def __init__(self, _state: _State[T] | None = None):
self._state = _state or _State(buffer=[])
self._closed = False
def close(self):
self._closed = True
self._state.closed = True
def join(self):
pass
def clone(self) -> MockReceiver[T]:
if self._closed:
raise ClosedResourceError
return MockReceiver(_state=self._state)
def clone_sender(self) -> MockSender[T]:
if self._closed:
raise ClosedResourceError
return MockSender(_state=self._state)
def receive_nowait(self) -> T:
if self._state.buffer:
return self._state.buffer.pop(0)
raise WouldBlock
def collect(self) -> list[T]:
out: list[T] = []
while True:
try:
out.append(self.receive_nowait())
except WouldBlock:
break
return out
async def receive_at_least(self, n: int) -> list[T]:
raise NotImplementedError
def __enter__(self):
return self
def __iter__(self) -> Iterator[T]:
while True:
try:
yield self.receive_nowait()
except WouldBlock:
break
def __exit__(self, exc_type, exc, tb):
# Don't swallow exceptions
return False
# Runner supervisor without multiprocessing logic.
@dataclass(frozen=True)
class FakeRunnerSupervisor:
bound_instance: BoundInstance
@@ -35,6 +130,8 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,
supports_tensor=False,
),
device_rank=device_rank,
world_size=world_size,
@@ -67,5 +164,21 @@ def get_mlx_ring_instance(
shard_assignments=get_shard_assignments(
model_id, node_to_runner, runner_to_shard
),
hosts=[],
hosts_by_node={},
ephemeral_port=0,
)
def get_bound_mlx_ring_instance(
instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
) -> BoundInstance:
shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=1)
instance = get_mlx_ring_instance(
instance_id=instance_id,
model_id=model_id,
node_to_runner={node_id: runner_id},
runner_to_shard={runner_id: shard},
)
return BoundInstance(
instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
)

View File

@@ -0,0 +1 @@
# TODO:

23
uv.lock generated
View File

@@ -336,6 +336,7 @@ dependencies = [
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -377,6 +378,7 @@ requires-dist = [
{ name = "mlx", specifier = ">=0.29.3" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "protobuf", specifier = ">=6.32.0" },
{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pydantic", specifier = ">=2.11.7" },
@@ -940,6 +942,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" },
]
[[package]]
name = "openai-harmony"
version = "0.0.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3e/92/2d038d096f29179c7c9571b431f9e739f87a487121901725e23fe338dd9d/openai_harmony-0.0.8.tar.gz", hash = "sha256:6e43f98e6c242fa2de6f8ea12eab24af63fa2ed3e89c06341fb9d92632c5cbdf", size = 284777, upload-time = "2025-11-05T19:07:06.727Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/c6/2502f416d46be3ec08bb66d696cccffb57781a499e3ff2e4d7c174af4e8f/openai_harmony-0.0.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:029ec25ca74abe48fdb58eb9fdd2a8c1618581fc33ce8e5653f8a1ffbfbd9326", size = 2627806, upload-time = "2025-11-05T19:06:57.063Z" },
{ url = "https://files.pythonhosted.org/packages/d3/d2/ce6953ca87db9cae3e775024184da7d1c5cb88cead19a2d75b42f00a959c/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4f709815924ec325b9a890e6ab2bbb0ceec8e319a4e257328eb752cf36b2efc", size = 2948463, upload-time = "2025-11-05T19:06:48.17Z" },
{ url = "https://files.pythonhosted.org/packages/fa/4c/b553c9651662d6ce102ca7f3629d268b23df1abe5841e24bed81e8a8e949/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cfcfd963b50a41fc656c84d3440ca6eecdccd6c552158ce790b8f2e33dfb5a9", size = 2704083, upload-time = "2025-11-05T19:06:50.205Z" },
{ url = "https://files.pythonhosted.org/packages/9b/af/4eec8f9ab9c27bcdb444460c72cf43011d176fc44c79d6e113094ca1e152/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a3a16972aa1cee38ea958470cd04ac9a2d5ac38fdcf77ab686611246220c158", size = 2959765, upload-time = "2025-11-05T19:06:53.62Z" },
{ url = "https://files.pythonhosted.org/packages/11/3c/33f3374e4624e0e776f6b13b73c45a7ead7f9c4529f8369ed5bfcaa30cac/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4d5cfa168e74d08f8ba6d58a7e49bc7daef4d58951ec69b66b0d56f4927a68d", size = 3427031, upload-time = "2025-11-05T19:06:51.829Z" },
{ url = "https://files.pythonhosted.org/packages/25/3f/1a192b93bb47c6b44cd98ba8cc1d3d2a9308f1bb700c3017e6352da11bda/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c007d277218a50db8839e599ed78e0fffe5130f614c3f6d93ae257f282071a29", size = 2953260, upload-time = "2025-11-05T19:06:55.406Z" },
{ url = "https://files.pythonhosted.org/packages/5b/f8/93b582cad3531797c3db7c2db5400fd841538ccddfd9f5e3df61be99a630/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8565d4f5a0638da1bffde29832ed63c9e695c558611053add3b2dc0b56c92dbc", size = 3127044, upload-time = "2025-11-05T19:06:59.553Z" },
{ url = "https://files.pythonhosted.org/packages/1d/10/4327dbf87f75ae813405fd9a9b4a5cde63d506ffed0a096a440a4cabd89c/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cbaa3bda75ef0d8836e1f8cc84af62f971b1d756d740efc95c38c3e04c0bfde2", size = 2932931, upload-time = "2025-11-05T19:07:01.437Z" },
{ url = "https://files.pythonhosted.org/packages/8a/c8/1774eec4f6f360ef57618fb8f52e3d3af245b2491bd0297513aa09eec04b/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:772922a9bd24e133950fad71eb1550836f415a88e8c77870e12d0c3bd688ddc2", size = 2996140, upload-time = "2025-11-05T19:07:03.438Z" },
{ url = "https://files.pythonhosted.org/packages/60/c3/3d1e01e2dba517a91760e4a03e4f20ffc75039a6fe584d0e6f9b5c78fd15/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:007b0476a1f331f8130783f901f1da6f5a7057af1a4891f1b6a31dec364189b5", size = 3205080, upload-time = "2025-11-05T19:07:05.078Z" },
]
[[package]]
name = "packaging"
version = "25.0"