Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)

Commit: Best master
@@ -1,8 +1,7 @@
import asyncio
import time
from asyncio.queues import Queue
from collections.abc import AsyncGenerator
from typing import Sequence, final
from typing import List, Sequence, final

import uvicorn
from fastapi import FastAPI
@@ -46,11 +45,11 @@ def chunk_to_response(chunk: TokenChunk) -> ChatCompletionResponse:

@final
class API:
    def __init__(self, command_queue: Queue[Command], global_events: AsyncSQLiteEventStorage) -> None:
    def __init__(self, command_buffer: List[Command], global_events: AsyncSQLiteEventStorage) -> None:
        self._app = FastAPI()
        self._setup_routes()

        self.command_queue = command_queue
        self.command_buffer = command_buffer
        self.global_events = global_events

    def _setup_routes(self) -> None:
@@ -105,7 +104,7 @@ class API:
                command_type=CommandTypes.CHAT_COMPLETION,
                request_params=payload,
            )
            await self.command_queue.put(request)
            self.command_buffer.append(request)

            finished = False
            while not finished:
@@ -139,11 +138,11 @@ class API:


def start_fastapi_server(
    command_queue: Queue[Command],
    command_buffer: List[Command],
    global_events: AsyncSQLiteEventStorage,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    api = API(command_queue, global_events)
    api = API(command_buffer, global_events)

    uvicorn.run(api.app, host=host, port=port)
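For illustration (not part of this commit): the hunks above swap the asyncio.Queue handoff for a plain Python list shared between the FastAPI thread and the master loop; the request handler appends a command and the master loop polls and pops it. A minimal, self-contained sketch of that pattern, with placeholder names apart from command_buffer:

import threading
import time

command_buffer: list[str] = []  # holds Command objects in the real code


def api_thread() -> None:
    # stands in for the route handler calling self.command_buffer.append(request)
    command_buffer.append("chat_completion_command")


def master_loop() -> None:
    deadline = time.time() + 1.0
    while time.time() < deadline:
        if command_buffer:
            command = command_buffer.pop(0)  # one command at a time, FIFO
            print(f"got command: {command}")
            return
        time.sleep(0.01)  # idle poll, mirroring the 10 ms sleep in the master loop


threading.Thread(target=api_thread, daemon=True).start()
master_loop()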
@@ -1,115 +0,0 @@
from collections.abc import Set
from typing import Literal

from shared.logging.common import LogEntry, LogEntryType


class MasterUninitializedLogEntry(LogEntry[Literal["master_uninitialized"]]):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_uninitialized"] = "master_uninitialized"
    message: str = "No master state found, creating new one."


class MasterCommandReceivedLogEntry(LogEntry[Literal["master_command_received"]]):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_command_received"] = "master_command_received"
    command_name: str


class MasterInvalidCommandReceivedLogEntry(
    LogEntry[Literal["master_invalid_command_received"]]
):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_invalid_command_received"] = (
        "master_invalid_command_received"
    )
    command_name: str


class MasterCommandRunnerNotRunningLogEntry(
    LogEntry[Literal["master_command_runner_not_running"]]
):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_command_runner_not_running"] = (
        "master_command_runner_not_running"
    )
    message: str = "Command Runner Not Running"


class MasterStateManagerStoppedLogEntry(
    LogEntry[Literal["master_state_manager_stopped"]]
):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_state_manager_stopped"] = "master_state_manager_stopped"
    message: str = "State Manager Stopped"


class EventCategoryUnknownLogEntry(LogEntry[Literal["event_category_unknown"]]):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["event_category_unknown"] = "event_category_unknown"
    event_category: str
    message: str = "Event Category Unknown, Skipping Event."


class StateUpdateLoopAlreadyRunningLogEntry(
    LogEntry[Literal["state_update_loop_already_running"]]
):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["state_update_loop_already_running"] = (
        "state_update_loop_already_running"
    )
    message: str = "State Update Loop Already Running"


class StateUpdateLoopStartedLogEntry(LogEntry[Literal["state_update_loop_started"]]):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["state_update_loop_started"] = "state_update_loop_started"
    message: str = "State Update Loop Started"


class StateUpdateLoopNotRunningLogEntry(
    LogEntry[Literal["state_update_loop_not_running"]]
):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["state_update_loop_not_running"] = (
        "state_update_loop_not_running"
    )
    message: str = "State Update Loop Not Running"


class StateUpdateLoopStoppedLogEntry(LogEntry[Literal["state_update_loop_stopped"]]):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["state_update_loop_stopped"] = "state_update_loop_stopped"
    message: str = "State Update Loop Stopped"


class StateUpdateErrorLogEntry(LogEntry[Literal["state_update_error"]]):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["state_update_error"] = "state_update_error"
    error: Exception


class StateUpdateEffectHandlerErrorLogEntry(
    LogEntry[Literal["state_update_effect_handler_error"]]
):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["state_update_effect_handler_error"] = (
        "state_update_effect_handler_error"
    )
    error: Exception


MasterLogEntries = (
    MasterUninitializedLogEntry
    | MasterCommandReceivedLogEntry
    | MasterInvalidCommandReceivedLogEntry
    | MasterCommandRunnerNotRunningLogEntry
    | MasterStateManagerStoppedLogEntry
    | EventCategoryUnknownLogEntry
    | StateUpdateLoopAlreadyRunningLogEntry
    | StateUpdateLoopStartedLogEntry
    | StateUpdateLoopNotRunningLogEntry
    | StateUpdateLoopStoppedLogEntry
    | StateUpdateErrorLogEntry
    | StateUpdateEffectHandlerErrorLogEntry
)
master/main.py
@@ -1,25 +1,52 @@
import asyncio
import os
import threading
from asyncio.queues import Queue
from logging import Logger
from pathlib import Path
from typing import List

from master.api import start_fastapi_server
from master.election_callback import ElectionCallbacks
from master.forwarder_supervisor import ForwarderSupervisor
from shared.apply import apply
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.common import NodeId
from shared.types.events import ChunkGenerated
from shared.types.events import (
    ChunkGenerated,
    CommandId,
    InstanceCreated,
    TaskCreated,
)
from shared.types.events.chunks import TokenChunk
from shared.types.events.commands import Command, CommandId
from shared.types.events.commands import (
    ChatCompletionCommand,
    Command,
    CreateInstanceCommand,
    DeleteInstanceCommand,
)
from shared.types.state import State
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from shared.types.worker.common import InstanceId
from shared.types.worker.instances import (
    InstanceParams,
    ShardAssignments,
    TypeOfInstance,
)
from shared.types.worker.runners import RunnerId
from shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata


## TODO: Hook this up properly
async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: CommandId):
    model_id = "testmodelabc"

    for i in range(10):
        await asyncio.sleep(0.1)

        # Create the event with proper types and consistent IDs
        chunk_event = ChunkGenerated(
            command_id=command_id,
@@ -31,7 +58,7 @@ async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: Comm
                token_id=i
            )
        )

        # ChunkGenerated needs to be cast to the expected BaseEvent type
        await events_log.append_events(
            [chunk_event],
@@ -51,7 +78,7 @@ async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: Comm
            token_id=11,
            finish_reason='stop'
        )
    )
)

    # ChunkGenerated needs to be cast to the expected BaseEvent type
    await events_log.append_events(
@@ -59,6 +86,111 @@ async def fake_tokens_task(events_log: AsyncSQLiteEventStorage, command_id: Comm
        origin=NodeId()
    )

def get_node_id() -> NodeId:
    return NodeId()  # TODO

class Master:
    def __init__(self, command_buffer: list[Command], global_events: AsyncSQLiteEventStorage, forwarder_binary_path: Path, logger: Logger):
        self.command_buffer = command_buffer
        self.global_events = global_events
        self.node_id = get_node_id()
        self.forwarder_supervisor = ForwarderSupervisor(
            forwarder_binary_path=forwarder_binary_path,
            logger=logger
        )
        self.election_callbacks = ElectionCallbacks(self.forwarder_supervisor, logger)
        self.logger = logger

    async def _get_state_snapshot(self) -> State:
        # TODO: for now start from scratch every time, but we can optimize this by keeping a snapshot on disk so we don't have to re-apply all events
        return State()

    async def run(self):
        self.state = await self._get_state_snapshot()

        # TODO: we should clean these up on shutdown
        await self.forwarder_supervisor.start_as_replica()
        if os.getenv('EXO_RUN_AS_REPLICA') in set(['TRUE', 'true', '1']):
            await self.election_callbacks.on_became_replica()
        else:
            await self.election_callbacks.on_became_master()

        while True:
            next_event = None
            # 1. process commands
            if len(self.command_buffer) > 0:
                # for now we do one command at a time
                next_command = self.command_buffer.pop(0)
                self.logger.info(f"got command: {next_command}")
                # TODO: validate the command
                match next_command:
                    case ChatCompletionCommand():
                        # 1. find a valid instance for this request, if none exists ERROR (TODO)
                        instance_id = InstanceId()
                        task_id = TaskId()
                        # 2. publish TaskCreated event (TODO)
                        next_event = TaskCreated(
                            task_id=task_id,
                            task=ChatCompletionTask(
                                task_id=task_id,
                                task_type=TaskType.CHAT_COMPLETION,
                                instance_id=instance_id,
                                task_status=TaskStatus.PENDING,
                                task_params=next_command.request_params
                            )
                        )
                    case DeleteInstanceCommand():
                        # TODO
                        pass
                    case CreateInstanceCommand():
                        if next_command.model_id not in MODEL_CARDS:
                            raise ValueError(f"Model {next_command.model_id} not supported.")

                        # TODO: we should also support models that aren't in MODEL_CARDS
                        # if it's in MODEL_CARDS, use ModelMetadata from there, otherwise interpret as a repo_id and get from huggingface
                        if next_command.model_id in MODEL_CARDS:
                            model_card = MODEL_CARDS[next_command.model_id]
                            model_meta = model_card.metadata
                        else:
                            model_meta = await get_model_meta(next_command.model_id)

                        # TODO: how do we actually schedule an instance? TODO: @@@@@@𝕾𝖊𝖙𝖍@@@@@@
                        next_event = InstanceCreated(
                            instance_id=InstanceId(),
                            instance_params=InstanceParams(
                                shard_assignments=ShardAssignments(
                                    model_id=next_command.model_id,
                                    runner_to_shard={
                                        RunnerId(): PipelineShardMetadata(
                                            model_meta=model_meta,
                                            partition_strategy=PartitionStrategy.pipeline,
                                            device_rank=0,
                                            world_size=1,
                                            start_layer=0,
                                            end_layer=0,
                                            n_layers=0
                                        )
                                    },
                                    node_to_runner={}
                                ),
                                hosts=[]
                            ),
                            instance_type=TypeOfInstance.ACTIVE,
                        )

            if next_event is not None:
                await self.global_events.append_events([next_event], origin=self.node_id)

            # 2. get latest events
            events = await self.global_events.get_events_since(self.state.last_event_applied_idx)
            if len(events) == 0:
                await asyncio.sleep(0.01)
                continue

            # 3. for each event, apply it to the state
            for event_from_log in events:
                self.state = apply(self.state, event_from_log)


async def main():
@@ -68,30 +200,21 @@ async def main():
    await event_log_manager.initialize()
    global_events: AsyncSQLiteEventStorage = event_log_manager.global_events

    command_queue: Queue[Command] = asyncio.Queue()
    command_buffer: List[Command] = []

    api_thread = threading.Thread(
        target=start_fastapi_server,
        args=(
            command_queue,
            command_buffer,
            global_events,
        ),
        daemon=True
    )
    api_thread.start()
    print('Running FastAPI server in a separate thread. Listening on port 8000.')

    while True:
        # master loop
        if not command_queue.empty():
            command = await command_queue.get()

            print(command)

            await fake_tokens_task(global_events, command_id=command.command_id)

        await asyncio.sleep(0.01)
    logger.info('Running FastAPI server in a separate thread. Listening on port 8000.')

    master = Master(command_buffer, global_events, forwarder_binary_path=Path("forwarder"), logger=logger)
    await master.run()

if __name__ == "__main__":
    asyncio.run(main())
    asyncio.run(main())
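For illustration only: Master.run above follows an event-sourcing shape, where commands turn into events appended to a log and state is rebuilt by folding an apply reducer over the events since the last applied index. A self-contained sketch of that shape with simplified stand-in types (State, TaskCreated, and apply here are toy versions, not the shared.* implementations):

from dataclasses import dataclass, field


@dataclass
class State:
    last_event_applied_idx: int = 0
    tasks: list[str] = field(default_factory=list)


@dataclass
class TaskCreated:
    task_id: str


def apply(state: State, event: TaskCreated) -> State:
    # pure reducer: returns the next state rather than mutating in place
    return State(
        last_event_applied_idx=state.last_event_applied_idx + 1,
        tasks=[*state.tasks, event.task_id],
    )


event_log = [TaskCreated("t1"), TaskCreated("t2")]
state = State()
for event in event_log[state.last_event_applied_idx:]:
    state = apply(state, event)
print(state)  # State(last_event_applied_idx=2, tasks=['t1', 't2'])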
master/tests/test_master.py (new file)
@@ -0,0 +1,71 @@
import asyncio
import tempfile
from logging import Logger
from pathlib import Path
from typing import List

import pytest

from master.main import Master
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.events import TaskCreated
from shared.types.events.commands import ChatCompletionCommand, Command, CommandId
from shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType


def _create_forwarder_dummy_binary() -> Path:
    path = Path(tempfile.mktemp()) / "forwarder.bin"
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(b"#!/bin/sh\necho dummy forwarder && sleep 1000000\n")
        path.chmod(0o755)
    return path


@pytest.mark.asyncio
async def test_master():
    logger = Logger(name='test_master_logger')
    event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
    await event_log_manager.initialize()
    global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
    await global_events.delete_all_events()

    command_buffer: List[Command] = []

    forwarder_binary_path = _create_forwarder_dummy_binary()

    master = Master(command_buffer=command_buffer, global_events=global_events, forwarder_binary_path=forwarder_binary_path, logger=logger)
    asyncio.create_task(master.run())

    command_buffer.append(
        ChatCompletionCommand(
            command_id=CommandId(),
            request_params=ChatCompletionTaskParams(
                model="llama-3.2-1b",
                messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")]
            )
        )
    )
    while len(await global_events.get_events_since(0)) == 0:
        await asyncio.sleep(0.001)

    events = await global_events.get_events_since(0)
    assert len(events) == 1
    assert events[0].idx_in_log == 1
    assert isinstance(events[0].event, TaskCreated)
    assert events[0].event == TaskCreated(
        task_id=events[0].event.task_id,
        task=ChatCompletionTask(
            task_id=events[0].event.task_id,
            task_type=TaskType.CHAT_COMPLETION,
            instance_id=events[0].event.task.instance_id,
            task_status=TaskStatus.PENDING,
            task_params=ChatCompletionTaskParams(
                model="llama-3.2-1b",
                messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")]
            )
        )
    )
    assert len(command_buffer) == 0
@@ -155,6 +155,13 @@ class AsyncSQLiteEventStorage:

        self._logger.info("Closed SQLite event storage")

    async def delete_all_events(self) -> None:
        """Delete all events from the database."""
        assert self._engine is not None
        async with AsyncSession(self._engine) as session:
            await session.execute(text("DELETE FROM events"))
            await session.commit()

    async def _initialize_database(self) -> None:
        """Initialize database connection and create tables."""
        self._engine = create_async_engine(
shared/logger.py
@@ -1,101 +0,0 @@
import logging
import logging.handlers
from collections.abc import Sequence, Set
from queue import Queue
from typing import Annotated

from pydantic import BaseModel, Field, TypeAdapter
from rich.logging import RichHandler

from master.logging import MasterLogEntries
from shared.logging.common import LogEntryType
from worker.logging import WorkerLogEntries

LogEntries = Annotated[
    MasterLogEntries | WorkerLogEntries, Field(discriminator="entry_type")
]
LogParser: TypeAdapter[LogEntries] = TypeAdapter(LogEntries)


class FilterLogByType(logging.Filter):
    def __init__(self, log_types: Set[LogEntryType]):
        super().__init__()
        self.log_types = log_types

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        LogParser.validate_json(message)
        return True


class LogEntry(BaseModel):
    event_type: Set[LogEntryType]


class LogFilterByType(logging.Filter):
    def __init__(self, log_types: Set[LogEntryType]):
        super().__init__()
        self.log_types = log_types

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        LogEntry.model_validate_json(message)
        return True


def configure_logger(
    logger_name: str,
    log_level: int = logging.INFO,
    effect_handlers: Sequence[logging.Handler] | None = None,
) -> logging.Logger:
    existing_logger = logging.Logger.manager.loggerDict.get(logger_name)
    if existing_logger is not None:
        raise RuntimeError(f"Logger with name '{logger_name}' already exists.")

    logger = logging.getLogger(logger_name)
    logger.setLevel(log_level)
    logger.propagate = False
    logging.raiseExceptions = True

    if logger.hasHandlers():
        return logger

    console_handler = RichHandler(
        rich_tracebacks=True,
    )
    console_handler.setLevel(log_level)

    logger.addHandler(console_handler)
    if effect_handlers is None:
        effect_handlers = []
    for effect_handler in effect_handlers:
        logger.addHandler(effect_handler)

    return logger


def attach_to_queue(
    logger: logging.Logger,
    filter_with: Sequence[logging.Filter],
    queue: Queue[logging.LogRecord],
) -> None:
    handler = logging.handlers.QueueHandler(queue)
    for log_filter in filter_with:
        handler.addFilter(log_filter)
    logger.addHandler(handler)


def create_queue_listener(
    log_queue: Queue[logging.LogRecord],
    effect_handlers: Sequence[logging.Handler],
) -> logging.handlers.QueueListener:
    listener = logging.handlers.QueueListener(
        log_queue, *effect_handlers, respect_handler_level=True
    )
    return listener


def log(
    logger: logging.Logger, log_entry: LogEntries, log_level: int = logging.INFO
) -> None:
    logger.log(log_level, log_entry.model_dump_json())
shared/models/model_cards.py (new file)
@@ -0,0 +1,252 @@
from typing import List

from pydantic import BaseModel

from shared.types.models import ModelMetadata


class ModelCard(BaseModel):
    id: str
    repo_id: str
    name: str
    description: str
    tags: List[str]
    metadata: ModelMetadata


MODEL_CARDS = {
    "llama-3.3": ModelCard(
        id="llama-3.3",
        repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
        name="Llama 3.3 70B",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
            pretty_name="Llama 3.3 70B",
            storage_size_kilobytes=38758160,
            n_layers=80,
        ),
    ),
    "llama-3.3:70b": ModelCard(
        id="llama-3.3:70b",
        repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
        name="Llama 3.3 70B",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
            pretty_name="Llama 3.3 70B",
            storage_size_kilobytes=38758160,
            n_layers=80,
        ),
    ),
    "llama-3.2": ModelCard(
        id="llama-3.2",
        repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
        name="Llama 3.2 1B",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
            pretty_name="Llama 3.2 1B",
            storage_size_kilobytes=678948,
            n_layers=16,
        ),
    ),
    "llama-3.2:1b": ModelCard(
        id="llama-3.2:1b",
        repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
        name="Llama 3.2 1B",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
            pretty_name="Llama 3.2 1B",
            storage_size_kilobytes=678948,
            n_layers=16,
        ),
    ),
    "llama-3.2:3b": ModelCard(
        id="llama-3.2:3b",
        repo_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
        name="Llama 3.2 3B",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
            pretty_name="Llama 3.2 3B",
            storage_size_kilobytes=1765062,
            n_layers=28,
        ),
    ),
    "llama-3.1:8b": ModelCard(
        id="llama-3.1:8b",
        repo_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        name="Llama 3.1 8B",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
            pretty_name="Llama 3.1 8B",
            storage_size_kilobytes=4411528,
            n_layers=32,
        ),
    ),
    "llama-3.1-70b": ModelCard(
        id="llama-3.1-70b",
        repo_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
        name="Llama 3.1 70B",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
            pretty_name="Llama 3.1 70B",
            storage_size_kilobytes=38758160,
            n_layers=80,
        ),
    ),
    "deepseek-r1": ModelCard(
        id="deepseek-r1",
        repo_id="mlx-community/DeepSeek-R1-4bit",
        name="DeepSeek R1 671B (4-bit)",
        description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/DeepSeek-R1-4bit",
            pretty_name="DeepSeek R1 671B (4-bit)",
            storage_size_kilobytes=409706307,
            n_layers=61,
        ),
    ),
    "deepseek-r1:671b": ModelCard(
        id="deepseek-r1:671b",
        repo_id="mlx-community/DeepSeek-R1-4bit",
        name="DeepSeek R1 671B",
        description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/DeepSeek-R1-4bit",
            pretty_name="DeepSeek R1 671B",
            storage_size_kilobytes=409706307,
            n_layers=61,
        ),
    ),
    "deepseek-v3": ModelCard(
        id="deepseek-v3",
        repo_id="mlx-community/DeepSeek-V3-0324-4bit",
        name="DeepSeek V3 4B",
        description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/DeepSeek-V3-0324-4bit",
            pretty_name="DeepSeek V3 4B",
            storage_size_kilobytes=368756663,
            n_layers=61,
        ),
    ),
    "deepseek-v3:671b": ModelCard(
        id="deepseek-v3:671b",
        repo_id="mlx-community/DeepSeek-V3-0324-4bit",
        name="DeepSeek V3 671B",
        description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/DeepSeek-V3-0324-4bit",
            pretty_name="DeepSeek V3 671B",
            storage_size_kilobytes=368756663,
            n_layers=61,
        ),
    ),
    "phi-3-mini": ModelCard(
        id="phi-3-mini",
        repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
        name="Phi 3 Mini 128k",
        description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
            pretty_name="Phi 3 Mini 128k",
            storage_size_kilobytes=2099262,
            n_layers=32,
        ),
    ),
    "phi-3-mini:128k": ModelCard(
        id="phi-3-mini:128k",
        repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
        name="Phi 3 Mini 128k",
        description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
            pretty_name="Phi 3 Mini 128k",
            storage_size_kilobytes=2099262,
            n_layers=32,
        ),
    ),
    "qwen3-0.6b": ModelCard(
        id="qwen3-0.6b",
        repo_id="mlx-community/Qwen3-0.6B-4bit",
        name="Qwen3 0.6B",
        description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Qwen3-0.6B-4bit",
            pretty_name="Qwen3 0.6B",
            storage_size_kilobytes=327512,
            n_layers=28,
        ),
    ),
    "qwen3-30b": ModelCard(
        id="qwen3-30b",
        repo_id="mlx-community/Qwen3-30B-A3B-4bit",
        name="Qwen3 30B (Active 3B)",
        description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/Qwen3-30B-A3B-4bit",
            pretty_name="Qwen3 30B (Active 3B)",
            storage_size_kilobytes=16772092,
            n_layers=48,
        ),
    ),
    "granite-3.3-2b": ModelCard(
        id="granite-3.3-2b",
        repo_id="mlx-community/granite-3.3-2b-instruct-fp16",
        name="Granite 3.3 2B",
        description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/granite-3.3-2b-instruct-fp16",
            pretty_name="Granite 3.3 2B",
            storage_size_kilobytes=4948320,
            n_layers=40,
        ),
    ),
    "granite-3.3-8b": ModelCard(
        id="granite-3.3-8b",
        repo_id="mlx-community/granite-3.3-8b-instruct-fp16",
        name="Granite 3.3 8B",
        description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/granite-3.3-8b-instruct-fp16",
            pretty_name="Granite 3.3 8B",
            storage_size_kilobytes=15958720,
            n_layers=40,
        ),
    ),
    "smol-lm-135m": ModelCard(
        id="smol-lm-135m",
        repo_id="mlx-community/SmolLM-135M-4bit",
        name="Smol LM 135M",
        description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/SmolLM-135M-4bit",
            pretty_name="Smol LM 135M",
            storage_size_kilobytes=73940,
            n_layers=30,
        ),
    ),
}
shared/models/model_meta.py (new file)
@@ -0,0 +1,89 @@
from typing import Annotated, Dict, Optional

import aiofiles
from huggingface_hub import model_info
from pydantic import BaseModel, Field

from shared.models.model_cards import MODEL_CARDS
from shared.types.models import ModelMetadata
from worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,
    ensure_exo_tmp,
)


class ConfigData(BaseModel):
    model_config = {"extra": "ignore"}  # Allow unknown fields

    # Common field names for number of layers across different architectures
    num_hidden_layers: Optional[Annotated[int, Field(ge=0)]] = None
    num_layers: Optional[Annotated[int, Field(ge=0)]] = None
    n_layer: Optional[Annotated[int, Field(ge=0)]] = None
    n_layers: Optional[Annotated[int, Field(ge=0)]] = None  # Sometimes used
    num_decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None  # Transformer models
    decoder_layers: Optional[Annotated[int, Field(ge=0)]] = None  # Some architectures

    @property
    def layer_count(self) -> int:
        # Check common field names for layer count
        layer_fields = [
            self.num_hidden_layers,
            self.num_layers,
            self.n_layer,
            self.n_layers,
            self.num_decoder_layers,
            self.decoder_layers,
        ]

        for layer_count in layer_fields:
            if layer_count is not None:
                return layer_count

        raise ValueError(f"No layer count found in config.json: {self.model_dump_json()}")

async def get_config_data(model_id: str) -> ConfigData:
    """Downloads and parses config.json for a model."""
    model_card = MODEL_CARDS[model_id]
    target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
    config_path = await download_file_with_retry(model_card.repo_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
    async with aiofiles.open(config_path, 'r') as f:
        return ConfigData.model_validate_json(await f.read())

async def get_safetensors_size(model_id: str) -> int:
    """Gets model size from safetensors index or falls back to HF API."""
    model_card = MODEL_CARDS[model_id]
    target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
    index_path = await download_file_with_retry(model_card.repo_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
    async with aiofiles.open(index_path, 'r') as f:
        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

    metadata = index_data.metadata
    if metadata is not None:
        return metadata.total_size

    info = model_info(model_id)
    if info.safetensors is None:
        raise ValueError(f"No safetensors info found for {model_id}")
    return info.safetensors.total

_model_meta_cache: Dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
    if model_id in _model_meta_cache:
        return _model_meta_cache[model_id]
    model_meta = await _get_model_meta(model_id)
    _model_meta_cache[model_id] = model_meta
    return model_meta

async def _get_model_meta(model_id: str) -> ModelMetadata:
    """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
    config_data = await get_config_data(model_id)
    num_layers = config_data.layer_count
    mem_size_bytes = await get_safetensors_size(model_id)

    return ModelMetadata(
        model_id=model_id,
        pretty_name=model_id,
        storage_size_kilobytes=mem_size_bytes // 1024,
        n_layers=num_layers,
    )
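For illustration only, a hedged usage sketch of how callers are expected to resolve model metadata with the two modules above: prefer the static MODEL_CARDS entry and otherwise fall back to get_model_meta, mirroring the CreateInstanceCommand branch in master/main.py (the fallback path there is still marked TODO, and get_model_meta currently also assumes the id is present in MODEL_CARDS):

import asyncio

from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta


async def resolve_metadata(model_id: str):
    if model_id in MODEL_CARDS:
        return MODEL_CARDS[model_id].metadata  # static card, no network access
    return await get_model_meta(model_id)      # downloads config.json and the safetensors index


if __name__ == "__main__":
    meta = asyncio.run(resolve_metadata("llama-3.2:1b"))
    print(meta.n_layers, meta.storage_size_kilobytes)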
@@ -96,3 +96,9 @@ class ChatCompletionTaskParams(BaseModel):
    tool_choice: str | dict[str, Any] | None = None
    parallel_tool_calls: bool | None = None
    user: str | None = None

class RequestInstanceTaskParams(BaseModel):
    model_id: str

class DeleteInstanceTaskParams(BaseModel):
    instance_id: str
@@ -33,7 +33,5 @@ class EventFromEventLog[T: Event](BaseModel):


type Apply = Callable[
    [State, Event],
    State
]
type Apply = Callable[[State, Event], State]
type ApplyFromEventLog = Callable[[State, EventFromEventLog[Event]], State]
shared/types/request.py (new file)
@@ -0,0 +1,23 @@
from pydantic import BaseModel

from shared.types.api import (
    ChatCompletionTaskParams,
    DeleteInstanceTaskParams,
    RequestInstanceTaskParams,
)
from shared.types.events import CommandId


class ChatCompletionCommand(BaseModel):
    command_id: CommandId
    command_params: ChatCompletionTaskParams

class RequestInstanceCommand(BaseModel):
    command_id: CommandId
    command_params: RequestInstanceTaskParams

class DeleteInstanceCommand(BaseModel):
    command_id: CommandId
    command_params: DeleteInstanceTaskParams

type Command = ChatCompletionCommand | RequestInstanceCommand | DeleteInstanceCommand
@@ -2,14 +2,14 @@ from pathlib import Path

import pytest

from shared.models.model_meta import get_model_meta
from shared.types.models import ModelMetadata
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.model_meta import _get_model_meta  # type: ignore


@pytest.fixture
def model_meta() -> ModelMetadata:
    return _get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')  # type: ignore
async def model_meta() -> ModelMetadata:
    return await get_model_meta('mlx-community/Llama-3.2-1B-Instruct-4bit')


@pytest.fixture
@@ -2,14 +2,14 @@ import asyncio
from pathlib import Path
from typing import AsyncIterator, Callable, Dict, List, Optional

from shared.models.model_cards import MODEL_CARDS
from shared.models.model_meta import get_model_meta
from shared.types.worker.shards import (
    PartitionStrategy,
    PipelineShardMetadata,
    ShardMetadata,
)
from worker.download.download_utils import RepoDownloadProgress, download_shard
from worker.download.model_cards import MODEL_CARDS
from worker.download.model_meta import get_model_meta
from worker.download.shard_downloader import ShardDownloader
@@ -1,133 +0,0 @@
from typing import List

from pydantic import BaseModel


class ModelCard(BaseModel):
    id: str
    repo_id: str
    name: str
    description: str
    tags: List[str]

MODEL_CARDS = {
    "llama-3.3": ModelCard(
        id="llama-3.3",
        repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
        name="Llama 3.3 70B",
        description="The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)",
        tags=[]),
    "llama-3.3:70b": ModelCard(
        id="llama-3.3:70b",
        repo_id="mlx-community/Llama-3.3-70B-Instruct-4bit",
        name="Llama 3.3 70B",
        description="The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)",
        tags=[]),
    "llama-3.2": ModelCard(
        id="llama-3.2",
        repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
        name="Llama 3.2 1B",
        description="Llama 3.2 is a large language model trained on the Llama 3.2 dataset.",
        tags=[]),
    "llama-3.2:1b": ModelCard(
        id="llama-3.2:1b",
        repo_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
        name="Llama 3.2 1B",
        description="Llama 3.2 is a large language model trained on the Llama 3.2 dataset.",
        tags=[]),
    "llama-3.2:3b": ModelCard(
        id="llama-3.2:3b",
        repo_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
        name="Llama 3.2 3B",
        description="Llama 3.2 is a large language model trained on the Llama 3.2 dataset.",
        tags=[]),
    "llama-3.1:8b": ModelCard(
        id="llama-3.1:8b",
        repo_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        name="Llama 3.1 8B",
        description="Llama 3.1 is a large language model trained on the Llama 3.1 dataset.",
        tags=[]),
    "llama-3.1-70b": ModelCard(
        id="llama-3.1-70b",
        repo_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
        name="Llama 3.1 70B",
        description="Llama 3.1 is a large language model trained on the Llama 3.1 dataset.",
        tags=[]),
    "deepseek-r1": ModelCard(
        id="deepseek-r1",
        repo_id="mlx-community/DeepSeek-R1-4bit",
        name="DeepSeek R1 671B (4-bit)",
        description="DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.",
        tags=[]),
    "deepseek-r1:671b": ModelCard(
        id="deepseek-r1:671b",  # TODO: make sure model_id matches up for identical models
        repo_id="mlx-community/DeepSeek-R1-4bit",
        name="DeepSeek R1 671B",
        description="DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.",
        tags=[]),
    "deepseek-v3": ModelCard(
        id="deepseek-v3",
        repo_id="mlx-community/DeepSeek-V3-0324-4bit",
        name="DeepSeek V3 4B",
        description="DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.",
        tags=[]),
    "deepseek-v3:671b": ModelCard(
        id="deepseek-v3:671b",
        repo_id="mlx-community/DeepSeek-V3-0324-4bit",
        name="DeepSeek V3 671B",
        description="DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.",
        tags=[]),
    "phi-3-mini": ModelCard(
        id="phi-3-mini",
        repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
        name="Phi 3 Mini 128k",
        description="Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.",
        tags=[]),
    "phi-3-mini:128k": ModelCard(
        id="phi-3-mini:128k",
        repo_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
        name="Phi 3 Mini 128k",
        description="Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.",
        tags=[]),
    "qwen3-0.6b": ModelCard(
        id="qwen3-0.6b",
        repo_id="mlx-community/Qwen3-0.6B-4bit",
        name="Qwen3 0.6B",
        description="Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.",
        tags=[]),
    "qwen3-30b": ModelCard(
        id="qwen3-30b",
        repo_id="mlx-community/Qwen3-30B-A3B-4bit",
        name="Qwen3 30B (Active 3B)",
        description="Qwen3 30B is a large language model trained on the Qwen3 30B dataset.",
        tags=[]),
    "granite-3.3-2b": ModelCard(
        id="granite-3.3-2b",
        repo_id="mlx-community/granite-3.3-2b-instruct-fp16",
        name="Granite 3.3 2B",
        description="Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.",
        tags=[]),
    "granite-3.3-8b": ModelCard(
        id="granite-3.3-8b",
        repo_id="mlx-community/granite-3.3-8b-instruct-fp16",
        name="Granite 3.3 8B",
        description="Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.",
        tags=[]),
    "smol-lm-135m": ModelCard(
        id="smol-lm-135m",
        repo_id="mlx-community/SmolLM-135M-4bit",
        name="Smol LM 135M",
        description="SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. ",
        tags=[]),
}

def get_huggingface_id(model: str) -> str:
    if "mlx-community/" in model:
        return model
    if model not in MODEL_CARDS:
        raise ValueError(f"Model {model} not found")
    return MODEL_CARDS[model].repo_id

if __name__ == "__main__":
    for model in MODEL_CARDS:
        print(f"{model} -> {get_huggingface_id(model)}")
@@ -1,124 +0,0 @@
import json
from typing import Annotated, Dict, Optional

import aiofiles
from huggingface_hub import model_info
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from pydantic import BaseModel, Field

from shared.types.models import ModelMetadata
from worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,
    ensure_exo_tmp,
)
from worker.download.model_cards import MODEL_CARDS


class ConfigData(BaseModel):
    num_hidden_layers: Optional[Annotated[int, Field(ge=0)]]
    num_layers: Optional[Annotated[int, Field(ge=0)]]
    n_layer: Optional[Annotated[int, Field(ge=0)]]

async def get_config_data(model_id: str) -> Optional[ConfigData]:
    """Downloads and parses config.json for a model."""
    try:
        model_card = MODEL_CARDS[model_id]
        target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
        config_path = await download_file_with_retry(model_card.repo_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
        async with aiofiles.open(config_path, 'r') as f:
            return ConfigData.model_validate_json(await f.read())
    except EntryNotFoundError:
        print(f"Warning: config.json not found for {model_id}. Layers/type from config unavailable.")
    except json.JSONDecodeError:
        print(f"Error: Failed to parse config.json for {model_id}.")
    except Exception as e:
        print(f"Error: Error processing config.json for {model_id}: {e}")
    return None

def get_num_layers(config_data: Optional[ConfigData], model_id: str) -> Optional[int]:
    """Extracts number of layers from config data."""
    if not config_data:
        return None

    if config_data.num_hidden_layers is not None:
        return config_data.num_hidden_layers
    if config_data.num_layers is not None:
        return config_data.num_layers
    if config_data.n_layer is not None:
        return config_data.n_layer

    print(f"Warning: No known layer key or valid number in config.json for {model_id}. Config: {config_data.model_dump_json()}")
    return None

async def get_safetensors_size(model_id: str) -> Optional[int]:
    """Gets model size from safetensors index or falls back to HF API."""
    try:
        model_card = MODEL_CARDS[model_id]
        target_dir = (await ensure_exo_tmp())/model_card.repo_id.replace("/", "--")
        index_path = await download_file_with_retry(model_card.repo_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
        async with aiofiles.open(index_path, 'r') as f:
            index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

        metadata = index_data.metadata
        if metadata is not None:
            return metadata.total_size
        print(f"Warning: Could not extract total_size from safetensors index metadata for {model_id}. Metadata: {index_data.model_dump_json()}")

    except EntryNotFoundError:
        print(f"Warning: model.safetensors.index.json not found for {model_id}.")
    except json.JSONDecodeError:
        print(f"Error: Failed to parse model.safetensors.index.json for {model_id}.")
    except Exception as e:
        print(f"Error: Error processing model.safetensors.index.json for {model_id}: {e}")

    print(f"Warning: Could not determine safetensors total size from index for {model_id}. Falling back to model_info API call.")
    try:
        info = model_info(model_id)
        if info.safetensors is not None:
            return info.safetensors.total
        print(f"Warning: Could not get safetensors total size from model_info API for {model_id}. Safetensors info: {info}")
    except HfHubHTTPError as e:
        print(f"Error: HTTP Error while fetching model info from API for {model_id}: {e}")
    except Exception as e:
        print(f"Error: Error getting total size from huggingface info API for {model_id}: {e}")
    return None

_model_meta_cache: Dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
    if model_id in _model_meta_cache:
        return _model_meta_cache[model_id]
    model_meta = await _get_model_meta(model_id)
    _model_meta_cache[model_id] = model_meta
    return model_meta

async def _get_model_meta(model_id: str) -> ModelMetadata:
    """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
    model_card = MODEL_CARDS[model_id]
    num_layers_val: Optional[int] = None
    mem_size_bytes_val: Optional[int] = None
    try:
        config_data = await get_config_data(model_id)
        # get_num_layers is synchronous
        num_layers_val = get_num_layers(config_data, model_id)
        mem_size_bytes_val = await get_safetensors_size(model_id)

    except HfHubHTTPError as e:
        print(f"Error: HTTP Error encountered for '{model_id}': {e}")
    except Exception as e:
        print(f"Error: Unexpected error during metadata fetching for '{model_id}': {e}")

    # Fallbacks for missing metadata
    if mem_size_bytes_val is None:
        print(f"Warning: Could not determine model size for {model_id}. Defaulting to 0 bytes.")
        mem_size_bytes_val = 0
    if num_layers_val is None:
        print(f"Warning: Could not determine number of layers for {model_id}. Defaulting to 0 layers.")
        num_layers_val = 0

    return ModelMetadata(
        model_id=model_id,
        pretty_name=model_card.name,
        storage_size_kilobytes=mem_size_bytes_val // 1024,
        n_layers=num_layers_val,
    )
@@ -1,13 +0,0 @@
from collections.abc import Set
from typing import Literal

from shared.logging.common import LogEntry, LogEntryType


class WorkerUninitialized(LogEntry[Literal["master_uninitialized"]]):
    entry_destination: Set[LogEntryType] = {LogEntryType.cluster}
    entry_type: Literal["master_uninitialized"] = "master_uninitialized"
    message: str = "No master state found, creating new one."


WorkerLogEntries = WorkerUninitialized
@@ -3,10 +3,11 @@ import os
from asyncio import Queue
from functools import partial
from logging import Logger
from typing import AsyncGenerator, Callable, Optional
from typing import AsyncGenerator, Optional

from pydantic import BaseModel, ConfigDict

from shared.apply import apply
from shared.db.sqlite import AsyncSQLiteEventStorage
from shared.types.common import NodeId
from shared.types.events import (
@@ -16,7 +17,6 @@ from shared.types.events import (
    RunnerStatusUpdated,
    TaskStateUpdated,
)
from shared.types.events.components import EventFromEventLog
from shared.types.state import State
from shared.types.tasks import TaskStatus
from shared.types.worker.common import RunnerId
@@ -74,15 +74,6 @@ class AssignedRunner(BaseModel):
            runner_status=self.status,
        )

# TODO: This should all be shared with the master.
type ApplyFromEventLog = Callable[[State, EventFromEventLog[Event]], State]
def get_apply_fn() -> ApplyFromEventLog:
    # TODO: this will get fixed in the worker-integration pr.
    def apply_fn(state: State, event_from_log: EventFromEventLog[Event]) -> State:
        return state

    return apply_fn

class Worker:
    def __init__(
        self,
@@ -479,8 +470,6 @@ class Worker:
    # Handle state updates
    async def _loop(self):
        assert self.worker_events is not None
        self.apply_fn = get_apply_fn()

        while True:
            # ToDo: Where do we update state? Do we initialize it from scratch & read all events in, or do we preload the state?

@@ -492,7 +481,7 @@ class Worker:

            # 2. for each event, apply it to the state and run sagas
            for event_from_log in events:
                self.state = self.apply_fn(self.state, event_from_log)
                self.state = apply(self.state, event_from_log)

            # 3. based on the updated state, we plan & execute an operation.
            op: RunnerOp | None = self.plan(self.state)

@@ -5,6 +5,7 @@ from pathlib import Path
from typing import Final, List, Optional, override
from uuid import UUID

from shared.models.model_cards import MODEL_CARDS, ModelCard
from shared.types.common import NodeId
from shared.types.models import ModelId, ModelMetadata
from shared.types.state import State
@@ -20,7 +21,6 @@ from shared.types.worker.runners import (
    ShardAssignments,
)
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.model_cards import MODEL_CARDS, ModelCard
from worker.main import AssignedRunner

NODE_A: Final[NodeId] = NodeId(uuid=UUID("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"))