Best master

Alex Cheema
2025-07-24 17:12:52 +01:00
committed by GitHub
parent 3730160477
commit 67c70b22e4
18 changed files with 610 additions and 539 deletions

View File

@@ -1,8 +1,7 @@
import asyncio
import time
from asyncio.queues import Queue
from collections.abc import AsyncGenerator
from typing import Sequence, final
from typing import List, Sequence, final
import uvicorn
from fastapi import FastAPI
@@ -46,11 +45,11 @@ def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:
@final
class API:
def __init__(self, command_queue: Queue[Command], global_events: AsyncSQLiteEventStorage) -> None:
def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage) -> None:
self._app = FastAPI()
self._setup_routes()
self.command_queue = command_queue
self.command_buffer = command_buffer
self.global_events = global_events
def _setup_routes(self) -> None:
@@ -105,7 +104,7 @@ class API:
command_type=CommandTypes.CHAT_COMPLETION,
request_params=payload,
)
await self.command_queue.put(request)
self.command_buffer.append(request)
finished = False
while not finished:
@@ -139,11 +138,11 @@ class API:
def start_fastapi_server(
command_queue: Queue[Command],
command_buffer: List[Command],
global_events: AsyncSQLiteEventStorage,
host: str = "0.0.0.0",
port: int = 8000,
):
api = API(command_queue, global_events)
api = API(command_buffer, global_events)
uvicorn.run(api.app, host=host, port=port)
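The API now shares a plain List[Command] with the master loop instead of an asyncio.Queue: the FastAPI thread appends, the master loop polls and pops. Below is a minimal, self-contained sketch of that hand-off; the names are illustrative stand-ins, not the repo's types.

import asyncio
import threading
import time
from typing import List

Command = str  # stand-in for shared.types.events.commands.Command

def api_thread(buffer: List[Command]) -> None:
    # Producer side: the FastAPI handler simply appends to the shared list.
    # list.append is atomic under the GIL, so no extra locking is used here.
    time.sleep(0.05)
    buffer.append("chat_completion")

async def master_loop(buffer: List[Command]) -> None:
    # Consumer side: poll the buffer and pop commands in FIFO order.
    while True:
        if buffer:
            print(f"got command: {buffer.pop(0)}")
            return
        await asyncio.sleep(0.01)

buffer: List[Command] = []
threading.Thread(target=api_thread, args=(buffer,), daemon=True).start()
asyncio.run(master_loop(buffer))

Unlike Queue.get, a list pop never blocks, so the consumer sleeps briefly between polls, the same 0.01 s cadence the master loop uses below.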

View File

@@ -1,115 +0,0 @@
from collections.abc import Set
from typing import Literal
from shared.logging.common import LogEntry, LogEntryType
class MasterUninitializedLogEntry(LogEntry[Literal["master_uninitialized"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["master_uninitialized"] = "master_uninitialized"
message: str = "No master state found, creating new one."
class MasterCommandReceivedLogEntry(LogEntry[Literal["master_command_received"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["master_command_received"] = "master_command_received"
command_name: str
class MasterInvalidCommandReceivedLogEntry(
LogEntry[Literal["master_invalid_command_received"]]
):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["master_invalid_command_received"] = (
"master_invalid_command_received"
)
command_name: str
class MasterCommandRunnerNotRunningLogEntry(
LogEntry[Literal["master_command_runner_not_running"]]
):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["master_command_runner_not_running"] = (
"master_command_runner_not_running"
)
message: str = "Command Runner Not Running"
class MasterStateManagerStoppedLogEntry(
LogEntry[Literal["master_state_manager_stopped"]]
):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["master_state_manager_stopped"] = "master_state_manager_stopped"
message: str = "State Manager Stopped"
class EventCategoryUnknownLogEntry(LogEntry[Literal["event_category_unknown"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["event_category_unknown"] = "event_category_unknown"
event_category: str
message: str = "Event Category Unknown, Skipping Event."
class StateUpdateLoopAlreadyRunningLogEntry(
LogEntry[Literal["state_update_loop_already_running"]]
):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["state_update_loop_already_running"] = (
"state_update_loop_already_running"
)
message: str = "State Update Loop Already Running"
class StateUpdateLoopStartedLogEntry(LogEntry[Literal["state_update_loop_started"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["state_update_loop_started"] = "state_update_loop_started"
message: str = "State Update Loop Started"
class StateUpdateLoopNotRunningLogEntry(
LogEntry[Literal["state_update_loop_not_running"]]
):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["state_update_loop_not_running"] = (
"state_update_loop_not_running"
)
message: str = "State Update Loop Not Running"
class StateUpdateLoopStoppedLogEntry(LogEntry[Literal["state_update_loop_stopped"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["state_update_loop_stopped"] = "state_update_loop_stopped"
message: str = "State Update Loop Stopped"
class StateUpdateErrorLogEntry(LogEntry[Literal["state_update_error"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["state_update_error"] = "state_update_error"
error: Exception
class StateUpdateEffectHandlerErrorLogEntry(
LogEntry[Literal["state_update_effect_handler_error"]]
):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["state_update_effect_handler_error"] = (
"state_update_effect_handler_error"
)
error: Exception
MasterLogEntries = (
MasterUninitializedLogEntry
| MasterCommandReceivedLogEntry
| MasterInvalidCommandReceivedLogEntry
| MasterCommandRunnerNotRunningLogEntry
| MasterStateManagerStoppedLogEntry
| EventCategoryUnknownLogEntry
| StateUpdateLoopAlreadyRunningLogEntry
| StateUpdateLoopStartedLogEntry
| StateUpdateLoopNotRunningLogEntry
| StateUpdateLoopStoppedLogEntry
| StateUpdateErrorLogEntry
| StateUpdateEffectHandlerErrorLogEntry
)

View File

@@ -1,25 +1,52 @@
import asyncio
import os
import threading
from asyncio.queues import Queue
from logging import Logger
from pathlib import Path
from typing import List
from master.api import start_fastapi_server
from master.election_callback import ElectionCallbacks
from master.forwarder_supervisor import ForwarderSupervisor
from shared.apply import apply
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.common import NodeId
from shared.types.events import ChunkGenerated
from shared.types.events import (
ChunkGenerated,
CommandId,
InstanceCreated,
TaskCreated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.events.commands import Command, CommandId
from shared.types.events.commands import (
ChatCompletionCommand,
Command,
CreateInstanceCommand,
DeleteInstanceCommand,
)
from shared.types.state import State
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import (
InstanceParams,
ShardAssignments,
TypeOfInstance,
)
from shared.types.worker.runners import RunnerId
from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
## TODO: Hook this up properly
async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: CommandId):
model_id = "testmodelabc"
for i in range(10):
await asyncio.sleep(0.1)
# Create the event with proper types and consistent IDs
chunk_event = ChunkGenerated(
command_id=command_id,
@@ -31,7 +58,7 @@ async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: Comm
token_id=i
)
)
# ChunkGenerated needs to be cast to the expected BaseEvent type
await events_log.append_events(
[chunk_event],
@@ -51,7 +78,7 @@ async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: Comm
token_id=11,
finish_reason='stop'
)
)
)
# ChunkGenerated needs to be cast to the expected BaseEvent type
await events_log.append_events(
@@ -59,6 +86,111 @@ async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: Comm
origin=NodeId()
)
def get_node_id() -> NodeId:
return NodeId() # TODO
class Master:
def __init__(self, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger):
self.command_buffer = command_buffer
self.global_events = global_events
self.node_id = get_node_id()
self.forwarder_supervisor = ForwarderSupervisor(
forwarder_binary_path=forwarder_binary_path,
logger=logger
)
self.election_callbacks = ElectionCallbacks(self.forwarder_supervisor, logger)
self.logger = logger
async def _get_state_snapshot(self) -> State:
# TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we don't have to re-apply all events
return State()
async def run(self):
self.state = await self._get_state_snapshot()
# TODO: we should clean these up on shutdown
await self.forwarder_supervisor.start_as_replica()
if os.getenv('EXO_RUN_AS_REPLICA') in {'TRUE', 'true', '1'}:
await self.election_callbacks.on_became_replica()
else:
await self.election_callbacks.on_became_master()
while True:
next_event = None
# 1. process commands
if len(self.command_buffer) > 0:
# for now we do one command at a time
next_command = self.command_buffer.pop(0)
self.logger.info(f"got command: {next_command}")
# TODO: validate the command
match next_command:
case ChatCompletionCommand():
# 1. find a valid instance for this request, if none exists ERROR (TODO)
instance_id = InstanceId()
task_id = TaskId()
# 2. publish TaskCreated event (TODO)
next_event = TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id=task_id,
task_type=TaskType.CHAT_COMPLETION,
instance_id=instance_id,
task_status=TaskStatus.PENDING,
task_params=next_command.request_params
)
)
case DeleteInstanceCommand():
# TODO
pass
case CreateInstanceCommand():
if next_command.model_id not in MODEL_CARDS:
raise ValueError(f"Model {next_command.model_id} not supported.")
# TODO: we should also support models that aren't in MODEL_CARDS
# if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface
if next_command.model_id in MODEL_CARDS:
model_card = MODEL_CARDS[next_command.model_id]
model_meta = model_card.metadata
else:
model_meta = await get_model_meta(next_command.model_id)
# TODO: how do we actually schedule an instance? TODO: @Seth
next_event = InstanceCreated(
instance_id=InstanceId(),
instance_params=InstanceParams(
shard_assignments=ShardAssignments(
model_id=next_command.model_id,
runner_to_shard={
RunnerId(): PipelineShardMetadata(
model_meta=model_meta,
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=0,
n_layers=0
)
},
node_to_runner={}
),
hosts=[]
),
instance_type=TypeOfInstance.ACTIVE,
)
if next_event is not None:
await self.global_events.append_events([next_event], origin=self.node_id)
# 2. get latest events
events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
if len(events) == 0:
await asyncio.sleep(0.01)
continue
# 3. for each event, apply it to the state
for event_from_log in events:
self.state = apply(self.state, event_from_log)
async def main():
@@ -68,30 +200,21 @@ async def main():
await event_log_manager.initialize()
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
command_queue: Queue[Command] = asyncio.Queue()
command_buffer: List[Command] = []
api_thread = threading.Thread(
target=start_fastapi_server,
args=(
command_queue,
command_buffer,
global_events,
),
daemon=True
)
api_thread.start()
print('Running FastAPI server in a separate thread. Listening on port 8000.')
while True:
# master loop
if not command_queue.empty():
command = await command_queue.get()
print(command)
await fake_tokens_task(global_events, command_id=command.command_id)
await asyncio.sleep(0.01)
logger.info('Running FastAPI server in a separate thread. Listening on port 8000.')
master = Master(command_buffer, global_events, forwarder_binary_path=Path("forwarder"), logger=logger)
await master.run()
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
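Master.run is an event-sourced loop: commands are turned into events, events are appended to the log, and state is rebuilt by folding apply over everything since last_event_applied_idx. The toy reduction below shows the pattern with illustrative types only; the real State, Event, and apply live under shared.*.

from dataclasses import dataclass, field
from typing import List

@dataclass
class Event:
    name: str

@dataclass
class State:
    last_event_applied_idx: int = 0
    applied: List[str] = field(default_factory=list)

def apply(state: State, event: Event) -> State:
    # Pure reducer: each event yields a new state, never mutating the old one.
    return State(state.last_event_applied_idx + 1, state.applied + [event.name])

log: List[Event] = [Event("TaskCreated"), Event("InstanceCreated")]
state = State()
for event in log[state.last_event_applied_idx:]:
    state = apply(state, event)
print(state)  # State(last_event_applied_idx=2, applied=['TaskCreated', 'InstanceCreated'])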

View File

@@ -0,0 +1,71 @@
import asyncio
import tempfile
from logging import Logger
from pathlib import Path
from typing import List
import pytest
from master.main import Master
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.events import TaskCreated
from shared.types.events.commands import ChatCompletionCommand, Command, CommandId
from shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
def _create_forwarder_dummy_binary() -> Path:
path = Path(tempfile.mktemp()) / "forwarder.bin"
if not path.exists():
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(b"#!/bin/sh\necho dummy forwarder && sleep 1000000\n")
path.chmod(0o755)
return path
@pytest.mark.asyncio
async def test_master():
logger = Logger(name='test_master_logger')
event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
await event_log_manager.initialize()
global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
await global_events.delete_all_events()
command_buffer: List[Command] = []
forwarder_binary_path = _create_forwarder_dummy_binary()
master = Master(command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
asyncio.create_task(master.run())
command_buffer.append(
ChatCompletionCommand(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")]
)
)
)
while len(await global_events.get_events_since(0)) == 0:
await asyncio.sleep(0.001)
events = await global_events.get_events_since(0)
assert len(events) == 1
assert events[0].idx_in_log == 1
assert isinstance(events[0].event, TaskCreated)
assert events[0].event == TaskCreated(
task_id=events[0].event.task_id,
task=ChatCompletionTask(
task_id=events[0].event.task_id,
task_type=TaskType.CHAT_COMPLETION,
instance_id=events[0].event.task.instance_id,
task_status=TaskStatus.PENDING,
task_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")]
)
)
)
assert len(command_buffer) == 0

View File

@@ -155,6 +155,13 @@ class AsyncSQLiteEventStorage:
self._logger.info("Closed SQLite event storage")
async def delete_all_events(self) -> None:
"""Delete all events from the database."""
assert self._engine is not None
async with AsyncSession(self._engine) as session:
await session.execute(text("DELETE FROM events"))
await session.commit()
async def _initialize_database(self) -> None:
"""Initialize database connection and create tables."""
self._engine = create_async_engine(

View File

@@ -1,101 +0,0 @@
import logging
import logging.handlers
from collections.abc import Sequence, Set
from queue import Queue
from typing import Annotated
from pydantic import BaseModel, Field, TypeAdapter
from rich.logging import RichHandler
from master.logging import MasterLogEntries
from shared.logging.common import LogEntryType
from worker.logging import WorkerLogEntries
LogEntries = Annotated[
MasterLogEntries | WorkerLogEntries, Field(discriminator="entry_type")
]
LogParser: TypeAdapter[LogEntries] = TypeAdapter(LogEntries)
class FilterLogByType(logging.Filter):
def __init__(self, log_types: Set[LogEntryType]):
super().__init__()
self.log_types = log_types
def filter(self, record: logging.LogRecord) -> bool:
message = record.getMessage()
LogParser.validate_json(message)
return True
class LogEntry(BaseModel):
event_type: Set[LogEntryType]
class LogFilterByType(logging.Filter):
def __init__(self, log_types: Set[LogEntryType]):
super().__init__()
self.log_types = log_types
def filter(self, record: logging.LogRecord) -> bool:
message = record.getMessage()
LogEntry.model_validate_json(message)
return True
def configure_logger(
logger_name: str,
log_level: int = logging.INFO,
effect_handlers: Sequence[logging.Handler] | None = None,
) -> logging.Logger:
existing_logger = logging.Logger.manager.loggerDict.get(logger_name)
if existing_logger is not None:
raise RuntimeError(f"Logger with name '{logger_name}' already exists.")
logger = logging.getLogger(logger_name)
logger.setLevel(log_level)
logger.propagate = False
logging.raiseExceptions = True
if logger.hasHandlers():
return logger
console_handler = RichHandler(
rich_tracebacks=True,
)
console_handler.setLevel(log_level)
logger.addHandler(console_handler)
if effect_handlers is None:
effect_handlers = []
for effect_handler in effect_handlers:
logger.addHandler(effect_handler)
return logger
def attach_to_queue(
logger: logging.Logger,
filter_with: Sequence[logging.Filter],
queue: Queue[logging.LogRecord],
) -> None:
handler = logging.handlers.QueueHandler(queue)
for log_filter in filter_with:
handler.addFilter(log_filter)
logger.addHandler(handler)
def create_queue_listener(
log_queue: Queue[logging.LogRecord],
effect_handlers: Sequence[logging.Handler],
) -> logging.handlers.QueueListener:
listener = logging.handlers.QueueListener(
log_queue, *effect_handlers, respect_handler_level=True
)
return listener
def log(
logger: logging.Logger, log_entry: LogEntries, log_level: int = logging.INFO
) -> None:
logger.log(log_level, log_entry.model_dump_json())

View File

@@ -0,0 +1,252 @@
from typing import List
from pydantic import BaseModel
from shared.types.models import ModelMetadata
class ModelCard(BaseModel):
id: str
repo_id: str
name: str
description: str
tags: List[str]
metadata: ModelMetadata
MODEL_CARDS = {
"llama-3.3": ModelCard(
id="llama-3.3",
repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
name="Llama 3.3 70B",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
pretty_name="Llama 3.3 70B",
storage_size_kilobytes=38758160,
n_layers=80,
),
),
"llama-3.3:70b": ModelCard(
id="llama-3.3:70b",
repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
name="Llama 3.3 70B",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
pretty_name="Llama 3.3 70B",
storage_size_kilobytes=38758160,
n_layers=80,
),
),
"llama-3.2": ModelCard(
id="llama-3.2",
repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
name="Llama 3.2 1B",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
pretty_name="Llama 3.2 1B",
storage_size_kilobytes=678948,
n_layers=16,
),
),
"llama-3.2:1b": ModelCard(
id="llama-3.2:1b",
repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
name="Llama 3.2 1B",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
pretty_name="Llama 3.2 1B",
storage_size_kilobytes=678948,
n_layers=16,
),
),
"llama-3.2:3b": ModelCard(
id="llama-3.2:3b",
repo_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
name="Llama 3.2 3B",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
pretty_name="Llama 3.2 3B",
storage_size_kilobytes=1765062,
n_layers=28,
),
),
"llama-3.1:8b": ModelCard(
id="llama-3.1:8b",
repo_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
name="Llama 3.1 8B",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
pretty_name="Llama 3.1 8B",
storage_size_kilobytes=4411528,
n_layers=32,
),
),
"llama-3.1-70b": ModelCard(
id="llama-3.1-70b",
repo_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
name="Llama 3.1 70B",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
pretty_name="Llama 3.1 70B",
storage_size_kilobytes=38758160,
n_layers=80,
),
),
"deepseek-r1": ModelCard(
id="deepseek-r1",
repo_id="mlx-community/DeepSeek-R1-4bit",
name="DeepSeek R1 671B (4-bit)",
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-R1-4bit",
pretty_name="DeepSeek R1 671B (4-bit)",
storage_size_kilobytes=409706307,
n_layers=61,
),
),
"deepseek-r1:671b": ModelCard(
id="deepseek-r1:671b",
repo_id="mlx-community/DeepSeek-R1-4bit",
name="DeepSeek R1 671B",
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-R1-4bit",
pretty_name="DeepSeek R1 671B",
storage_size_kilobytes=409706307,
n_layers=61,
),
),
"deepseek-v3": ModelCard(
id="deepseek-v3",
repo_id="mlx-community/DeepSeek-V3-0324-4bit",
name="DeepSeek V3 4B",
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-V3-0324-4bit",
pretty_name="DeepSeek V3 4B",
storage_size_kilobytes=368756663,
n_layers=61,
),
),
"deepseek-v3:671b": ModelCard(
id="deepseek-v3:671b",
repo_id="mlx-community/DeepSeek-V3-0324-4bit",
name="DeepSeek V3 671B",
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/DeepSeek-V3-0324-4bit",
pretty_name="DeepSeek V3 671B",
storage_size_kilobytes=368756663,
n_layers=61,
),
),
"phi-3-mini": ModelCard(
id="phi-3-mini",
repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
name="Phi 3 Mini 128k",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
pretty_name="Phi 3 Mini 128k",
storage_size_kilobytes=2099262,
n_layers=32,
),
),
"phi-3-mini:128k": ModelCard(
id="phi-3-mini:128k",
repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
name="Phi 3 Mini 128k",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
pretty_name="Phi 3 Mini 128k",
storage_size_kilobytes=2099262,
n_layers=32,
),
),
"qwen3-0.6b": ModelCard(
id="qwen3-0.6b",
repo_id="mlx-community/Qwen3-0.6B-4bit",
name="Qwen3 0.6B",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Qwen3-0.6B-4bit",
pretty_name="Qwen3 0.6B",
storage_size_kilobytes=327512,
n_layers=28,
),
),
"qwen3-30b": ModelCard(
id="qwen3-30b",
repo_id="mlx-community/Qwen3-30B-A3B-4bit",
name="Qwen3 30B (Active 3B)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/Qwen3-30B-A3B-4bit",
pretty_name="Qwen3 30B (Active 3B)",
storage_size_kilobytes=16772092,
n_layers=48,
),
),
"granite-3.3-2b": ModelCard(
id="granite-3.3-2b",
repo_id="mlx-community/granite-3.3-2b-instruct-fp16",
name="Granite 3.3 2B",
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/granite-3.3-2b-instruct-fp16",
pretty_name="Granite 3.3 2B",
storage_size_kilobytes=4948320,
n_layers=40,
),
),
"granite-3.3-8b": ModelCard(
id="granite-3.3-8b",
repo_id="mlx-community/granite-3.3-8b-instruct-fp16",
name="Granite 3.3 8B",
description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/granite-3.3-8b-instruct-fp16",
pretty_name="Granite 3.3 8B",
storage_size_kilobytes=15958720,
n_layers=40,
),
),
"smol-lm-135m": ModelCard(
id="smol-lm-135m",
repo_id="mlx-community/SmolLM-135M-4bit",
name="Smol LM 135M",
description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
tags=[],
metadata=ModelMetadata(
model_id="mlx-community/SmolLM-135M-4bit",
pretty_name="Smol LM 135M",
storage_size_kilobytes=73940,
n_layers=30,
),
),
}
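As a quick sanity check of the card table, this is roughly how the CreateInstanceCommand branch above resolves a user-facing id to repo and layer information. It assumes it runs inside the repo, where shared.models.model_cards is importable.

from shared.models.model_cards import MODEL_CARDS

model_id = "llama-3.2:1b"
if model_id not in MODEL_CARDS:
    raise ValueError(f"Model {model_id} not supported.")

card = MODEL_CARDS[model_id]
print(card.repo_id)                          # mlx-community/Llama-3.2-1B-Instruct-4bit
print(card.metadata.n_layers)                # 16
print(card.metadata.storage_size_kilobytes)  # 678948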

View File

@@ -0,0 +1,89 @@
from typing import Annotated, Dict, Optional
import aiofiles
from huggingface_hub import model_info
from pydantic import BaseModel, Field
from shared.models.model_cards import MODEL_CARDS
from shared.types.models import ModelMetadata
from worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_exo_tmp,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Optional[Annotated[int, Field(ge=0)]] = None
num_layers: Optional[Annotated[int, Field(ge=0)]] = None
n_layer: Optional[Annotated[int, Field(ge=0)]] = None
n_layers: Optional[Annotated[int, Field(ge=0)]] = None # Sometimes used
num_decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None # Transformer models
decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None # Some architectures
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(f"No layer count found in config.json: {self.model_dump_json()}")
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
model_card = MODEL_CARDS[model_id]
target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
config_path = await download_file_with_retry(model_card.repo_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(config_path, 'r') as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> int:
"""Gets model size from safetensors index or falls back to HF API."""
model_card = MODEL_CARDS[model_id]
target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
index_path = await download_file_with_retry(model_card.repo_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(index_path, 'r') as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return metadata.total_size
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return info.safetensors.total
_model_meta_cache: Dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
return ModelMetadata(
model_id=model_id,
pretty_name=model_id,
storage_size_kilobytes=mem_size_bytes // 1024,
n_layers=num_layers,
)
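A small usage sketch for the cached lookup above. It assumes the repo environment plus network access to Hugging Face, since config.json and the safetensors index are downloaded on a cache miss.

import asyncio
from shared.models.model_meta import get_model_meta

async def main() -> None:
    meta = await get_model_meta("llama-3.2")   # key must exist in MODEL_CARDS
    print(meta.n_layers, meta.storage_size_kilobytes)
    again = await get_model_meta("llama-3.2")  # second call is served from _model_meta_cache
    assert again is meta

asyncio.run(main())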

View File

@@ -96,3 +96,9 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
class RequestInstanceTaskParams(BaseModel):
model_id: str
class DeleteInstanceTaskParams(BaseModel):
instance_id: str

View File

@@ -33,7 +33,5 @@ class EventFromEventLog[T: Event](BaseModel):
type Apply = Callable[
[State, Event],
State
]
type Apply = Callable[[State, Event], State]
type ApplyFromEventLog = Callable[[State, EventFromEventLog[Event]], State]

shared/types/request.py (new file, 23 lines)
View File

@@ -0,0 +1,23 @@
from pydantic import BaseModel
from shared.types.api import (
ChatCompletionTaskParams,
DeleteInstanceTaskParams,
RequestInstanceTaskParams,
)
from shared.types.events import CommandId
class ChatCompletionCommand(BaseModel):
command_id: CommandId
command_params: ChatCompletionTaskParams
class RequestInstanceCommand(BaseModel):
command_id: CommandId
command_params: RequestInstanceTaskParams
class DeleteInstanceCommand(BaseModel):
command_id: CommandId
command_params: DeleteInstanceTaskParams
type Command = ChatCompletionCommand | RequestInstanceCommand | DeleteInstanceCommand
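For reference, constructing one of these new request-level commands looks like the sketch below, written against the definitions above and assuming the module is importable as shared.types.request. Note the field here is command_params, whereas the event-level commands used elsewhere in this commit carry request_params.

from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.events import CommandId
from shared.types.request import ChatCompletionCommand

command = ChatCompletionCommand(
    command_id=CommandId(),
    command_params=ChatCompletionTaskParams(
        model="llama-3.2-1b",
        messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")],
    ),
)
print(command.model_dump_json())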

View File

@@ -2,14 +2,14 @@ from pathlib import Path
import pytest
from shared.models.model_meta import get_model_meta
from shared.types.models import ModelMetadata
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.model_meta import _get_model_meta # type: ignore
@pytest.fixture
def model_meta() -> ModelMetadata:
return _get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit') # type: ignore
async def model_meta() -> ModelMetadata:
return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')
@pytest.fixture

View File

@@ -2,14 +2,14 @@ import asyncio
from pathlib import Path
from typing import AsyncIterator, Callable, Dict, List, Optional
from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.worker.shards import (
PartitionStrategy,
PipelineShardMetadata,
ShardMetadata,
)
from worker.download.download_utils import RepoDownloadProgress, download_shard
from worker.download.model_cards import MODEL_CARDS
from worker.download.model_meta import get_model_meta
from worker.download.shard_downloader import ShardDownloader

View File

@@ -1,133 +0,0 @@
from typing import List
from pydantic import BaseModel
class ModelCard(BaseModel):
id: str
repo_id: str
name: str
description: str
tags: List[str]
MODEL_CARDS = {
"llama-3.3": ModelCard(
id="llama-3.3",
repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
name="Llama 3.3 70B",
description="The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)",
tags=[]),
"llama-3.3:70b": ModelCard(
id="llama-3.3:70b",
repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
name="Llama 3.3 70B",
description="The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)",
tags=[]),
"llama-3.2": ModelCard(
id="llama-3.2",
repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
name="Llama 3.2 1B",
description="Llama 3.2 is a large language model trained on the Llama 3.2 dataset.",
tags=[]),
"llama-3.2:1b": ModelCard(
id="llama-3.2:1b",
repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
name="Llama 3.2 1B",
description="Llama 3.2 is a large language model trained on the Llama 3.2 dataset.",
tags=[]),
"llama-3.2:3b": ModelCard(
id="llama-3.2:3b",
repo_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
name="Llama 3.2 3B",
description="Llama 3.2 is a large language model trained on the Llama 3.2 dataset.",
tags=[]),
"llama-3.1:8b": ModelCard(
id="llama-3.1:8b",
repo_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
name="Llama 3.1 8B",
description="Llama 3.1 is a large language model trained on the Llama 3.1 dataset.",
tags=[]),
"llama-3.1-70b": ModelCard(
id="llama-3.1-70b",
repo_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
name="Llama 3.1 70B",
description="Llama 3.1 is a large language model trained on the Llama 3.1 dataset.",
tags=[]),
"deepseek-r1": ModelCard(
id="deepseek-r1",
repo_id="mlx-community/DeepSeek-R1-4bit",
name="DeepSeek R1 671B (4-bit)",
description="DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.",
tags=[]),
"deepseek-r1:671b": ModelCard(
id="deepseek-r1:671b", # TODO: make sure model_id matches up for identical models
repo_id="mlx-community/DeepSeek-R1-4bit",
name="DeepSeek R1 671B",
description="DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.",
tags=[]),
"deepseek-v3": ModelCard(
id="deepseek-v3",
repo_id="mlx-community/DeepSeek-V3-0324-4bit",
name="DeepSeek V3 4B",
description="DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.",
tags=[]),
"deepseek-v3:671b": ModelCard(
id="deepseek-v3:671b",
repo_id="mlx-community/DeepSeek-V3-0324-4bit",
name="DeepSeek V3 671B",
description="DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.",
tags=[]),
"phi-3-mini": ModelCard(
id="phi-3-mini",
repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
name="Phi 3 Mini 128k",
description="Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.",
tags=[]),
"phi-3-mini:128k": ModelCard(
id="phi-3-mini:128k",
repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
name="Phi 3 Mini 128k",
description="Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.",
tags=[]),
"qwen3-0.6b": ModelCard(
id="qwen3-0.6b",
repo_id="mlx-community/Qwen3-0.6B-4bit",
name="Qwen3 0.6B",
description="Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.",
tags=[]),
"qwen3-30b": ModelCard(
id="qwen3-30b",
repo_id="mlx-community/Qwen3-30B-A3B-4bit",
name="Qwen3 30B (Active 3B)",
description="Qwen3 30B is a large language model trained on the Qwen3 30B dataset.",
tags=[]),
"granite-3.3-2b": ModelCard(
id="granite-3.3-2b",
repo_id="mlx-community/granite-3.3-2b-instruct-fp16",
name="Granite 3.3 2B",
description="Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.",
tags=[]),
"granite-3.3-8b": ModelCard(
id="granite-3.3-8b",
repo_id="mlx-community/granite-3.3-8b-instruct-fp16",
name="Granite 3.3 8B",
description="Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.",
tags=[]),
"smol-lm-135m": ModelCard(
id="smol-lm-135m",
repo_id="mlx-community/SmolLM-135M-4bit",
name="Smol LM 135M",
description="SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. ",
tags=[]),
}
def get_huggingface_id(model: str) -> str:
if "mlx-community/" in model:
return model
if model not in MODEL_CARDS:
raise ValueError(f"Model {model} not found")
return MODEL_CARDS[model].repo_id
if __name__ == "__main__":
for model in MODEL_CARDS:
print(f"{model} -> {get_huggingface_id(model)}")

View File

@@ -1,124 +0,0 @@
import json
from typing import Annotated, Dict, Optional
import aiofiles
from huggingface_hub import model_info
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from pydantic import BaseModel, Field
from shared.types.models import ModelMetadata
from worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_exo_tmp,
)
from worker.download.model_cards import MODEL_CARDS
class ConfigData(BaseModel):
num_hidden_layers: Optional[Annotated[int, Field(ge=0)]]
num_layers: Optional[Annotated[int, Field(ge=0)]]
n_layer: Optional[Annotated[int, Field(ge=0)]]
async def get_config_data(model_id: str) -> Optional[ConfigData]:
"""Downloads and parses config.json for a model."""
try:
model_card = MODEL_CARDS[model_id]
target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
config_path = await download_file_with_retry(model_card.repo_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(config_path, 'r') as f:
return ConfigData.model_validate_json(await f.read())
except EntryNotFoundError:
print(f"Warning: config.json not found for {model_id}. Layers/type from config unavailable.")
except json.JSONDecodeError:
print(f"Error: Failed to parse config.json for {model_id}.")
except Exception as e:
print(f"Error: Error processing config.json for {model_id}: {e}")
return None
def get_num_layers(config_data: Optional[ConfigData], model_id: str) -> Optional[int]:
"""Extracts number of layers from config data."""
if not config_data:
return None
if config_data.num_hidden_layers is not None:
return config_data.num_hidden_layers
if config_data.num_layers is not None:
return config_data.num_layers
if config_data.n_layer is not None:
return config_data.n_layer
print(f"Warning: No known layer key or valid number in config.json for {model_id}. Config: {config_data.model_dump_json()}")
return None
async def get_safetensors_size(model_id: str) -> Optional[int]:
"""Gets model size from safetensors index or falls back to HF API."""
try:
model_card = MODEL_CARDS[model_id]
target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
index_path = await download_file_with_retry(model_card.repo_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
async with aiofiles.open(index_path, 'r') as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return metadata.total_size
print(f"Warning: Could not extract total_size from safetensors index metadata for {model_id}. Metadata: {index_data.model_dump_json()}")
except EntryNotFoundError:
print(f"Warning: model.safetensors.index.json not found for {model_id}.")
except json.JSONDecodeError:
print(f"Error: Failed to parse model.safetensors.index.json for {model_id}.")
except Exception as e:
print(f"Error: Error processing model.safetensors.index.json for {model_id}: {e}")
print(f"Warning: Could not determine safetensors total size from index for {model_id}. Falling back to model_info API call.")
try:
info = model_info(model_id)
if info.safetensors is not None:
return info.safetensors.total
print(f"Warning: Could not get safetensors total size from model_info API for {model_id}. Safetensors info: {info}")
except HfHubHTTPError as e:
print(f"Error: HTTP Error while fetching model info from API for {model_id}: {e}")
except Exception as e:
print(f"Error: Error getting total size from huggingface info API for {model_id}: {e}")
return None
_model_meta_cache: Dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
model_card = MODEL_CARDS[model_id]
num_layers_val: Optional[int] = None
mem_size_bytes_val: Optional[int] = None
try:
config_data = await get_config_data(model_id)
# get_num_layers is synchronous
num_layers_val = get_num_layers(config_data, model_id)
mem_size_bytes_val = await get_safetensors_size(model_id)
except HfHubHTTPError as e:
print(f"Error: HTTP Error encountered for '{model_id}': {e}")
except Exception as e:
print(f"Error: Unexpected error during metadata fetching for '{model_id}': {e}")
# Fallbacks for missing metadata
if mem_size_bytes_val is None:
print(f"Warning: Could not determine model size for {model_id}. Defaulting to 0 bytes.")
mem_size_bytes_val = 0
if num_layers_val is None:
print(f"Warning: Could not determine number of layers for {model_id}. Defaulting to 0 layers.")
num_layers_val = 0
return ModelMetadata(
model_id=model_id,
pretty_name=model_card.name,
storage_size_kilobytes=mem_size_bytes_val // 1024,
n_layers=num_layers_val,
)

View File

@@ -1,13 +0,0 @@
from collections.abc import Set
from typing import Literal
from shared.logging.common import LogEntry, LogEntryType
class WorkerUninitialized(LogEntry[Literal["master_uninitialized"]]):
entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
entry_type: Literal["master_uninitialized"] = "master_uninitialized"
message: str = "No master state found, creating new one."
WorkerLogEntries = WorkerUninitialized

View File

@@ -3,10 +3,11 @@ import os
from asyncio import Queue
from functools import partial
from logging import Logger
from typing import AsyncGenerator, Callable, Optional
from typing import AsyncGenerator, Optional
from pydantic import BaseModel, ConfigDict
from shared.apply import apply
from shared.db.sqlite import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import (
@@ -16,7 +17,6 @@ from shared.types.events import (
RunnerStatusUpdated,
TaskStateUpdated,
)
from shared.types.events.components import EventFromEventLog
from shared.types.state import State
from shared.types.tasks import TaskStatus
from shared.types.worker.common import RunnerId
@@ -74,15 +74,6 @@ class AssignedRunner(BaseModel):
runner_status=self.status,
)
# TODO: This should all be shared with the master.
type ApplyFromEventLog = Callable[[State, EventFromEventLog[Event]], State]
def get_apply_fn() -> ApplyFromEventLog:
# TODO: this will get fixed in the worker-integration pr.
def apply_fn(state: State, event_from_log: EventFromEventLog[Event]) -> State:
return state
return apply_fn
class Worker:
def __init__(
self,
@@ -479,8 +470,6 @@ class Worker:
# Handle state updates
async def _loop(self):
assert self.worker_events is not None
self.apply_fn = get_apply_fn()
while True:
# ToDo: Where do we update state? Do we initialize it from scratch & read all events in, or do we preload the state?
@@ -492,7 +481,7 @@ class Worker:
# 2. for each event, apply it to the state and run sagas
for event_from_log in events:
self.state = self.apply_fn(self.state, event_from_log)
self.state = apply(self.state, event_from_log)
# 3. based on the updated state, we plan & execute an operation.
op: RunnerOp | None = self.plan(self.state)

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from typing import Final, List, Optional, override
from uuid import UUID
from shared.models.model_cards import MODEL_CARDS, ModelCard
from shared.types.common import NodeId
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
@@ -20,7 +21,6 @@ from shared.types.worker.runners import (
ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.model_cards import MODEL_CARDS, ModelCard
from worker.main import AssignedRunner
NODE_A: Final[NodeId] = NodeId(uuid=UUID("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"))