Mirror of https://github.com/exo-explore/exo.git, synced 2025-12-23 22:27:50 -05:00
Commit message: wuff
pyproject.toml

@@ -30,6 +30,7 @@ dependencies = [
     "anyio==4.11.0",
     "bidict>=0.23.1",
     "mlx>=0.29.3",
+    "mlx[cpu]>=0.29.3; sys_platform == 'linux'",
     "mlx-lm>=0.28.3",
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",
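The added requirement combines an extra with an environment marker, so Linux nodes pull in the CPU wheel while macOS keeps the default (Metal) build. A minimal sketch of how such a spec parses, assuming the third-party `packaging` library (not a dependency this commit adds):

import packaging.requirements

# Parse the exact string added to pyproject.toml above.
req = packaging.requirements.Requirement("mlx[cpu]>=0.29.3; sys_platform == 'linux'")
print(req.name)       # mlx
print(req.extras)     # {'cpu'}
print(req.specifier)  # >=0.29.3
# The marker evaluates against the current interpreter:
# False on macOS, True on Linux, so the extra is skipped on Macs.
print(req.marker.evaluate())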
@@ -10,7 +10,6 @@ KEEP_KV_SIZE: int | None = 1600
 QUANTIZE_MODEL_MODE: str | None = "affine"
 CACHE_GROUP_SIZE: int = 64
 KV_CACHE_BITS: int | None = 8
-TEMPERATURE: float = 1.0
 
 # TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
 TRUST_REMOTE_CODE: bool = True
@@ -5,6 +5,7 @@ import time
 from pathlib import Path
 from typing import Any, Callable, cast
 
+from mlx.core.distributed import Group
 from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
 from mlx_lm.models.deepseek_v3 import DeepseekV3Model
 from mlx_lm.sample_utils import make_sampler
@@ -13,7 +14,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
 from exo.worker.engines.mlx.constants import (
     CACHE_GROUP_SIZE,
     KV_CACHE_BITS,
-    TEMPERATURE,
     TRUST_REMOTE_CODE,
 )
 
@@ -67,7 +67,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
     )
 
 
-def mx_barrier(group: mx.distributed.Group | None = None):
+def mx_barrier(group: Group | None = None):
     mx.eval(
         mx.distributed.all_sum(
             mx.array(1.0),
@@ -77,7 +77,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
     )
 
 
-def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
+def broadcast_from_zero(value: int, group: Group | None = None):
    if group is None:
        return value
 
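These two hunks only swap the spelled-out `mx.distributed.Group` annotation for the newly imported `Group` alias; behavior is unchanged. The barrier itself works because a collective cannot complete until every rank participates. A standalone sketch of the idiom (not exo code):

import mlx.core as mx
from mlx.core.distributed import Group

def barrier(group: Group | None = None) -> None:
    # all_sum is a collective op: forcing its evaluation blocks this rank
    # until every rank in the group has contributed its value, so the
    # eval acts as a synchronization point.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), group=group))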
@@ -99,15 +99,9 @@ class HostList(RootModel[list[str]]):
 
 def mlx_distributed_init(
     bound_instance: BoundInstance,
-) -> mx.distributed.Group:
+) -> Group:
     """
-    Initialize the MLX distributed (runs in thread pool).
-
-    Either hosts or mlx_ibv_devices must be provided:
-    - hosts: traditional host-based connectivity using MLX_HOSTFILE
-    - mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
-    - mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
-    - strict: if True, raise an error if the distributed backend is not available
+    Initialize MLX distributed.
     """
     rank = bound_instance.bound_shard.device_rank
     logger.info(f"Starting initialization for rank {rank}")
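The deleted docstring was the only written description of how connectivity is chosen; that information now lives solely in the implementation. For reference, a minimal sketch of the underlying MLX call, with the env-var wiring taken from the removed docstring (the exact exo plumbing is assumed):

import mlx.core as mx

# Host-based connectivity: the ring backend reads MLX_HOSTFILE.
# RDMA setups instead use an MLX_IBV_DEVICES connectivity matrix plus an
# MLX_IBV_COORDINATOR address (IP:PORT), per the docstring removed above.
group = mx.distributed.init(backend="ring")
print(f"rank {group.rank()} of {group.size()}")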
@@ -154,36 +148,34 @@ def mlx_distributed_init(
 
 def initialize_mlx(
     bound_instance: BoundInstance,
-) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
-    """
-    Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
-    """
+) -> Group | None:
+    # should we unseed it?
+    # TODO: pass in seed from params
     mx.random.seed(42)
 
-    set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
+    if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
+        return None
+    return mlx_distributed_init(bound_instance)
 
-    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
+
+def load_mlx_items(
+    bound_instance: BoundInstance, group: Group | None
+) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
+    # TODO: pass temperature
+    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
     logger.info("Created a sampler")
 
-    if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
+    if group is None:
         logger.info(f"Single device used for {bound_instance.instance}")
         model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
         start_time = time.perf_counter()
         model, _ = load_model(model_path, strict=True)
         end_time = time.perf_counter()
         logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
-        if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model):  # type: ignore
-            pass
-            # model, config = quantize_model(
-            #     model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
-            # )
 
         tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
 
     else:
         logger.info("Starting distributed init")
-        group = mlx_distributed_init(bound_instance)
-
         start_time = time.perf_counter()
         model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
         end_time = time.perf_counter()
@@ -193,8 +185,6 @@ def initialize_mlx(
 
     set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
 
-    logger.debug(model)
-
     return cast(Model, model), tokenizer, sampler
 
 
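Taken together, these hunks split the old all-in-one initialize_mlx into two phases the runner can drive from separate tasks: group setup (ConnectToGroup) and model loading (LoadModel). A sketch of the resulting call sequence, assuming a valid BoundInstance and simplified from the runner changes below:

from exo.worker.engines.mlx.utils_mlx import initialize_mlx, load_mlx_items

def bring_up(bound_instance):
    # Phase 1: seed the RNG and join the distributed group.
    # Returns None when the instance occupies a single node.
    group = initialize_mlx(bound_instance)
    # Phase 2: load weights, tokenizer, and sampler; shard across the
    # group when one exists, otherwise a plain single-device load_model().
    return load_mlx_items(bound_instance, group)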
@@ -39,6 +39,7 @@ from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
 from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
 from exo.worker.engines.mlx.utils_mlx import (
     initialize_mlx,
+    load_mlx_items,
     mlx_force_oom,
 )
 from exo.worker.runner.bootstrap import logger
@@ -66,9 +67,10 @@ def main(
     model = None
     tokenizer = None
     sampler = None
+    group = None
 
     current_status: RunnerStatus = RunnerIdle()
-    logger.info("runner waiting for model")
+    logger.info("runner created")
     event_sender.send(
         RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
     )
@@ -81,7 +83,9 @@ def main(
             )
             event_sender.send(TaskAcknowledged(task_id=task.task_id))
             match task:
-                case ConnectToGroup() if isinstance(RunnerIdle, RunnerFailed):
+                case ConnectToGroup() if isinstance(
+                    current_status, (RunnerIdle, RunnerFailed)
+                ):
                     logger.info("runner connecting")
                     current_status = RunnerConnecting()
                     event_sender.send(
@@ -89,15 +93,12 @@ def main(
                             runner_id=runner_id, runner_status=current_status
                         )
                     )
-                    model, tokenizer, sampler = initialize_mlx(bound_instance)
+                    group = initialize_mlx(bound_instance)
 
                     logger.info("runner connected")
                     current_status = RunnerConnected()
 
-                case LoadModel() if isinstance(
-                    current_status, RunnerConnected
-                ):
+                case LoadModel() if isinstance(current_status, RunnerConnected):
                     current_status = RunnerLoading()
                     logger.info("runner loading")
                     event_sender.send(
@@ -106,7 +107,9 @@ def main(
                         )
                     )
 
-                    model, tokenizer, sampler = initialize_mlx(bound_instance)
+                    model, tokenizer, sampler = load_mlx_items(
+                        bound_instance, group
+                    )
 
                     current_status = RunnerLoaded()
                     logger.info("runner loaded")
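The match guards gate each task on the runner's current status, so an out-of-order task is ignored rather than crashing the runner. Note the old guard `isinstance(RunnerIdle, RunnerFailed)` compared two classes and was therefore always False; the fix checks current_status against a tuple of allowed states. A self-contained sketch of the corrected pattern (toy classes, not exo code):

class RunnerIdle: ...
class RunnerFailed: ...
class ConnectToGroup: ...

current_status = RunnerIdle()
task = ConnectToGroup()

match task:
    case ConnectToGroup() if isinstance(current_status, (RunnerIdle, RunnerFailed)):
        print("accepted: runner is idle or failed")
    case ConnectToGroup():
        print("ignored: wrong runner state for this task")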
@@ -24,3 +24,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
 
 COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
 COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")
+
+SHUTDOWN_TASK_ID = TaskId("shutdown")
+CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
+INITIALIZATION_TASK_ID = TaskId("initialisation")
+LOAD_TASK_ID = TaskId("load")
+WARMUP_TASK_ID = TaskId("warmup")
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
 
 from exo.shared.types.common import NodeId
@@ -14,6 +16,7 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
 from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
 
 
+# Runner supervisor without multiprocessing logic.
 @dataclass(frozen=True)
 class FakeRunnerSupervisor:
     bound_instance: BoundInstance
@@ -35,6 +38,8 @@ def get_pipeline_shard_metadata(
             pretty_name=str(model_id),
             storage_size=Memory.from_mb(100000),
             n_layers=32,
+            # hidden_size=2048,
+            # supports_tensor=False,
         ),
         device_rank=device_rank,
         world_size=world_size,
@@ -69,3 +74,18 @@ def get_mlx_ring_instance(
         ),
         hosts=[],
     )
+
+
+def get_bound_mlx_ring_instance(
+    instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
+) -> BoundInstance:
+    shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=1)
+    instance = get_mlx_ring_instance(
+        instance_id=instance_id,
+        model_id=model_id,
+        node_to_runner={node_id: runner_id},
+        runner_to_shard={runner_id: shard},
+    )
+    return BoundInstance(
+        instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
+    )
@@ -0,0 +1,199 @@
+# Check tasks are complete before runner is ever ready.
+from collections.abc import Iterable
+import pytest
+
+import exo.worker.runner.runner as mlx_runner
+from exo.shared.types.api import ChatCompletionMessage
+from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.events import (
+    Event,
+    ChunkGenerated,
+    RunnerStatusUpdated,
+    TaskAcknowledged,
+    TaskStatusUpdated,
+)
+from exo.shared.types.tasks import (
+    ChatCompletion,
+    ChatCompletionTaskParams,
+    ConnectToGroup,
+    LoadModel,
+    Shutdown,
+    StartWarmup,
+    TaskStatus,
+    Task,
+)
+from exo.shared.types.worker.runner_response import GenerationResponse
+from exo.shared.types.worker.runners import (
+    RunnerIdle,
+    RunnerLoaded,
+    RunnerLoading,
+    RunnerReady,
+    RunnerRunning,
+    RunnerShutdown,
+    RunnerConnecting,
+    RunnerConnected,
+    RunnerWarmingUp,
+)
+from exo.utils.channels import mp_channel
+
+from ...constants import (
+    CHAT_COMPLETION_TASK_ID,
+    COMMAND_1_ID,
+    INITIALIZATION_TASK_ID,
+    INSTANCE_1_ID,
+    LOAD_TASK_ID,
+    MODEL_A_ID,
+    NODE_A,
+    RUNNER_1_ID,
+    SHUTDOWN_TASK_ID,
+    WARMUP_TASK_ID,
+)
+from ..conftest import get_bound_mlx_ring_instance
+
+
+INIT_TASK = ConnectToGroup(
+    task_id=INITIALIZATION_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+LOAD_TASK = LoadModel(
+    task_id=LOAD_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+WARMUP_TASK = StartWarmup(
+    task_id=WARMUP_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+SHUTDOWN_TASK = Shutdown(
+    task_id=SHUTDOWN_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+    runner_id=RUNNER_1_ID,
+)
+
+CHAT_PARAMS = ChatCompletionTaskParams(
+    model=str(MODEL_A_ID),
+    messages=[ChatCompletionMessage(role="user", content="hello")],
+    stream=True,
+    max_tokens=4,
+    temperature=0.0,
+)
+
+CHAT_TASK = ChatCompletion(
+    task_id=CHAT_COMPLETION_TASK_ID,
+    command_id=COMMAND_1_ID,
+    task_params=CHAT_PARAMS,
+    instance_id=INSTANCE_1_ID,
+)
+
+
+def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
+    for test_event, true_event in zip(test_events, true_events, strict=True):
+        test_event.event_id = true_event.event_id
+        assert test_event == true_event, f"{test_event} != {true_event}"
+
+
+@pytest.fixture
+def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setattr(mlx_runner, "initialize_mlx", lambda bound_instance: object())
+    monkeypatch.setattr(
+        mlx_runner,
+        "load_mlx_items",
+        lambda bound_instance, group: (object(), object(), object()),
+    )
+    monkeypatch.setattr(mlx_runner, "warmup_inference", lambda **kwargs: 1)
+    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", lambda *_: None)
+
+    def fake_generate(model, tokenizer, sampler, task):
+        yield GenerationResponse(token=0, text="hi", finish_reason="stop")
+
+    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
+
+
+def _run(tasks: Iterable[Task]):
+    bound_instance = get_bound_mlx_ring_instance(
+        instance_id=INSTANCE_1_ID,
+        model_id=MODEL_A_ID,
+        runner_id=RUNNER_1_ID,
+        node_id=NODE_A,
+    )
+
+    task_sender, task_receiver = mp_channel[Task]()
+    event_sender, event_receiver = mp_channel[Event]()
+
+    with task_sender, event_receiver:
+        for t in tasks:
+            task_sender.send(t)
+
+        # worst monkeypatch known to man
+        def nothin() -> None: pass
+        event_sender.close = nothin
+        event_sender.join = nothin
+        task_receiver.close = nothin
+        task_receiver.join = nothin
+
+        mlx_runner.main(bound_instance, event_sender, task_receiver)
+
+    return event_receiver.collect()
+
+
+def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
+    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
+
+    expected_chunk = ChunkGenerated(
+        command_id=COMMAND_1_ID,
+        chunk=TokenChunk(
+            idx=0,
+            model=MODEL_A_ID,
+            text="hi",
+            token_id=0,
+            finish_reason="stop",
+        ),
+    )
+
+    assert_events_equal(
+        events,
+        [
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
+            TaskStatusUpdated(
+                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
+            ),
+            TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
+            RunnerStatusUpdated(
+                runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
+            ),
+            TaskStatusUpdated(
+                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
+            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=LOAD_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
+            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
+            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=WARMUP_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
+            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
+            TaskStatusUpdated(
+                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
+            ),
+            TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
+            expected_chunk,
+            TaskStatusUpdated(
+                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
+            TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
+            TaskStatusUpdated(
+                task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
+        ],
+    )
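One detail worth noting: the fixture patches initialize_mlx and friends on the runner module itself, not on utils_mlx. That works because `from ... import initialize_mlx` copies the binding into runner.py's own namespace, and that copy is what main() resolves at call time; patching the source module would leave the runner's copy untouched. A self-contained sketch of the mechanism (toy module, not exo code):

import types

mod = types.ModuleType("demo_runner")
exec(
    "def initialize_mlx(b): return 'real'\n"
    "def main(b): return initialize_mlx(b)",
    mod.__dict__,
)

assert mod.main(None) == "real"
# Equivalent to monkeypatch.setattr(mod, "initialize_mlx", ...):
mod.initialize_mlx = lambda b: "patched"
assert mod.main(None) == "patched"  # main resolves the name at call time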
@@ -0,0 +1 @@
+# TODO:
uv.lock (generated)
@@ -334,6 +334,7 @@ dependencies = [
     { name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
     { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -375,6 +376,7 @@ requires-dist = [
     { name = "hypercorn", specifier = ">=0.18.0" },
     { name = "loguru", specifier = ">=0.7.3" },
     { name = "mlx", specifier = ">=0.29.3" },
+    { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.29.3" },
     { name = "mlx-lm", specifier = ">=0.28.3" },
     { name = "networkx", specifier = ">=3.5" },
     { name = "protobuf", specifier = ">=6.32.0" },
@@ -795,6 +797,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/f2/90/d481dd70b351e28718cfc9a0deb229a75e140abda3ed59284cf635f93f12/mlx-0.29.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:e217a99ece66832a2e631131df32e9feb047276b68ac59ca0ad63735842f6dd0", size = 649781, upload-time = "2025-10-17T19:21:26.075Z" },
 ]
 
+[package.optional-dependencies]
+cpu = [
+    { name = "mlx-cpu", marker = "sys_platform == 'linux'" },
+]
+
+[[package]]
+name = "mlx-cpu"
+version = "0.29.3"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/6d/ff/474abb13000ca641985084055c145a70c1214973d867979ebfe7420c2df2/mlx_cpu-0.29.3-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:e76763434a9d1d878bb0d6dd965ad319a0a63b0b1d69314e4c97d8332f5e7170", size = 10225301, upload-time = "2025-10-17T19:24:03.544Z" },
+]
+
 [[package]]
 name = "mlx-lm"
 version = "0.28.3"