Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-30 16:51:02 -05:00)

Compare commits (1 commit): leo/profil... → improve-di...

| Author | SHA1 | Date |
|---|---|---|
| | c3753f8631 | |
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
    model_card = await ModelCard.load(model_id)
    model_card = await ModelCard.from_hf(model_id)
    return PipelineShardMetadata(
        model_card=model_card,
        device_rank=0,
@@ -166,8 +166,9 @@ class ResumableShardDownloader(ShardDownloader):
        for task in asyncio.as_completed(tasks):
            try:
                yield await task
            # TODO: except Exception
            except Exception as e:
                logger.warning(f"Error downloading shard: {type(e).__name__}")
                logger.error("Error downloading shard:", e)

    async def get_shard_download_status_for_shard(
        self, shard: ShardMetadata
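(The hunk above wraps each completed download in try/except so that one failing shard does not abort the rest. A minimal standalone sketch of that as_completed pattern, independent of exo's downloader; names here are illustrative only.)

    import asyncio

    async def gather_tolerant(coros: list) -> list:
        results = []
        for fut in asyncio.as_completed(coros):
            try:
                results.append(await fut)
            except Exception as e:
                # one failed download should not abort the remaining ones
                print(f"Error downloading shard: {type(e).__name__}")
        return results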
@@ -40,7 +40,6 @@ class Node:

    node_id: NodeId
    event_index_counter: Iterator[int]
    _profiling_in_progress: set[NodeId] = field(default_factory=set)
    _tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)

    @classmethod

@@ -87,7 +86,6 @@ class Node:
        else:
            api = None

        profiling_in_progress: set[NodeId] = set()
        if not args.no_worker:
            worker = Worker(
                node_id,

@@ -98,7 +96,6 @@ class Node:
                command_sender=router.sender(topics.COMMANDS),
                download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
                event_index_counter=event_index_counter,
                profiling_in_progress=profiling_in_progress,
            )
        else:
            worker = None

@@ -136,7 +133,6 @@ class Node:
            api,
            node_id,
            event_index_counter,
            profiling_in_progress,
        )

    async def run(self):

@@ -243,7 +239,6 @@ class Node:
                    topics.DOWNLOAD_COMMANDS
                ),
                event_index_counter=self.event_index_counter,
                profiling_in_progress=self._profiling_in_progress,
            )
            self._tg.start_soon(self.worker.run)
        if self.api:
@@ -8,8 +8,8 @@ from typing import Annotated, Literal, cast
from uuid import uuid4

import anyio
from anyio import BrokenResourceError, ClosedResourceError, EndOfStream, create_task_group
from anyio.abc import SocketStream, TaskGroup
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse

@@ -65,9 +65,7 @@ from exo.shared.types.api import (
    StartDownloadParams,
    StartDownloadResponse,
    StreamingChoiceResponse,
    StreamOptions,
    ToolCall,
    Usage,
)
from exo.shared.types.chunks import (
    ErrorChunk,

@@ -115,9 +113,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)


def chunk_to_response(
    chunk: TokenChunk | ToolCallChunk,
    command_id: CommandId,
    usage: Usage | None,
    chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
    return ChatCompletionResponse(
        id=command_id,

@@ -142,10 +138,21 @@ def chunk_to_response(
                finish_reason=chunk.finish_reason,
            )
        ],
        usage=usage,
    )


async def resolve_model_card(model_id: ModelId) -> ModelCard:
    if model_id in MODEL_CARDS:
        model_card = MODEL_CARDS[model_id]
        return model_card

    for card in MODEL_CARDS.values():
        if card.model_id == ModelId(model_id):
            return card

    return await ModelCard.from_hf(model_id)


class API:
    def __init__(
        self,
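(A minimal sketch of how the new resolve_model_card helper is expected to behave. "qwen3-0.6b" appears in MODEL_CARDS elsewhere in this diff; the second id is a hypothetical placeholder, not taken from the change.)

    # A registry key hits MODEL_CARDS directly; anything else falls back to
    # fetching the card from Hugging Face via ModelCard.from_hf.
    card = await resolve_model_card(ModelId("qwen3-0.6b"))
    card = await resolve_model_card(ModelId("some-org/some-hf-repo"))  # placeholder repo id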
@@ -267,7 +274,7 @@ class API:

    async def place_instance(self, payload: PlaceInstanceParams):
        command = PlaceInstance(
            model_card=await ModelCard.load(payload.model_id),
            model_card=await resolve_model_card(payload.model_id),
            sharding=payload.sharding,
            instance_meta=payload.instance_meta,
            min_nodes=payload.min_nodes,

@@ -284,7 +291,7 @@ class API:
        self, payload: CreateInstanceParams
    ) -> CreateInstanceResponse:
        instance = payload.instance
        model_card = await ModelCard.load(instance.shard_assignments.model_id)
        model_card = await resolve_model_card(instance.shard_assignments.model_id)
        required_memory = model_card.storage_size
        available_memory = self._calculate_total_available_memory()

@@ -312,7 +319,7 @@ class API:
        instance_meta: InstanceMeta = InstanceMeta.MlxRing,
        min_nodes: int = 1,
    ) -> Instance:
        model_card = await ModelCard.load(model_id)
        model_card = await resolve_model_card(model_id)

        try:
            placements = get_instance_placements(

@@ -515,10 +522,9 @@ class API:
        del self._chat_completion_queues[command_id]

    async def _generate_chat_stream(
        self, command_id: CommandId, stream_options: StreamOptions | None = None
        self, command_id: CommandId
    ) -> AsyncGenerator[str, None]:
        """Generate chat completion stream as JSON strings."""
        include_usage = stream_options.include_usage if stream_options else False

        async for chunk in self._chat_chunk_stream(command_id):
            assert not isinstance(chunk, ImageChunk)

@@ -534,10 +540,8 @@ class API:
                yield "data: [DONE]\n\n"
                return

            usage = chunk.usage if include_usage else None

            chunk_response: ChatCompletionResponse = chunk_to_response(
                chunk, command_id, usage=usage
                chunk, command_id
            )
            logger.debug(f"chunk_response: {chunk_response}")
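(A minimal client-side sketch of consuming this stream. Only the SSE framing, "data: ..." lines terminated by "data: [DONE]", comes from the code above; the endpoint path, port, and model id are assumptions for illustration.)

    import json
    import httpx

    async def read_chat_stream(prompt: str) -> str:
        text = ""
        payload = {
            "model": "qwen3-0.6b",  # example id; any resolvable model works
            "messages": [{"role": "user", "content": prompt}],
            "stream": True,
        }
        async with httpx.AsyncClient(timeout=None) as client:
            # /v1/chat/completions and port 52415 are assumed values for this sketch.
            url = "http://localhost:52415/v1/chat/completions"
            async with client.stream("POST", url, json=payload) as resp:
                async for line in resp.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[len("data: "):]
                    if data == "[DONE]":
                        break
                    chunk = json.loads(data)
                    # OpenAI-style streaming response: text arrives in choices[0].delta
                    delta = chunk["choices"][0].get("delta", {})
                    text += delta.get("content", "") or ""
        return text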
@@ -553,9 +557,8 @@ class API:

        text_parts: list[str] = []
        tool_calls: list[ToolCall] = []
        model: ModelId | None = None
        model: str | None = None
        finish_reason: FinishReason | None = None
        usage: Usage | None = None

        async for chunk in self._chat_chunk_stream(command_id):
            if isinstance(chunk, ErrorChunk):

@@ -580,9 +583,6 @@ class API:
                    for i, tool in enumerate(chunk.tool_calls)
                )

            if chunk.usage is not None:
                usage = chunk.usage

            if chunk.finish_reason is not None:
                finish_reason = chunk.finish_reason

@@ -604,7 +604,6 @@ class API:
                    finish_reason=finish_reason,
                )
            ],
            usage=usage,
        )

    async def _collect_chat_completion_with_stats(

@@ -612,7 +611,7 @@ class API:
    ) -> BenchChatCompletionResponse:
        text_parts: list[str] = []
        tool_calls: list[ToolCall] = []
        model: ModelId | None = None
        model: str | None = None
        finish_reason: FinishReason | None = None

        stats: GenerationStats | None = None

@@ -665,7 +664,7 @@ class API:
        )
        return resp

    async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
    async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
        logger.warning(
            "TODO: we should send a notification to the user to download the model"
        )

@@ -674,7 +673,7 @@ class API:
        self, payload: ChatCompletionTaskParams
    ) -> ChatCompletionResponse | StreamingResponse:
        """Handle chat completions, supporting both streaming and non-streaming responses."""
        model_card = await ModelCard.load(ModelId(payload.model))
        model_card = await resolve_model_card(ModelId(payload.model))
        payload.model = model_card.model_id

        if not any(

@@ -692,7 +691,7 @@ class API:
        await self._send(command)
        if payload.stream:
            return StreamingResponse(
                self._generate_chat_stream(command.command_id, payload.stream_options),
                self._generate_chat_stream(command.command_id),
                media_type="text/event-stream",
            )

@@ -701,7 +700,7 @@ class API:
    async def bench_chat_completions(
        self, payload: BenchChatCompletionTaskParams
    ) -> BenchChatCompletionResponse:
        model_card = await ModelCard.load(ModelId(payload.model))
        model_card = await resolve_model_card(ModelId(payload.model))
        payload.model = model_card.model_id

        if not any(

@@ -721,12 +720,12 @@ class API:
        response = await self._collect_chat_completion_with_stats(command.command_id)
        return response

    async def _validate_image_model(self, model: ModelId) -> ModelId:
    async def _validate_image_model(self, model: str) -> ModelId:
        """Validate model exists and return resolved model ID.

        Raises HTTPException 404 if no instance is found for the model.
        """
        model_card = await ModelCard.load(model)
        model_card = await resolve_model_card(ModelId(model))
        resolved_model = model_card.model_id
        if not any(
            instance.shard_assignments.model_id == resolved_model

@@ -772,7 +771,7 @@ class API:
        When stream=True and partial_images > 0, returns a StreamingResponse
        with SSE-formatted events for partial and final images.
        """
        payload.model = await self._validate_image_model(ModelId(payload.model))
        payload.model = await self._validate_image_model(payload.model)

        command = ImageGeneration(
            request_params=payload,

@@ -1017,7 +1016,7 @@ class API:
    async def bench_image_generations(
        self, request: Request, payload: BenchImageGenerationTaskParams
    ) -> BenchImageGenerationResponse:
        payload.model = await self._validate_image_model(ModelId(payload.model))
        payload.model = await self._validate_image_model(payload.model)

        payload.stream = False
        payload.partial_images = 0

@@ -1038,7 +1037,7 @@ class API:
        self,
        image: UploadFile,
        prompt: str,
        model: ModelId,
        model: str,
        n: int,
        size: str,
        response_format: Literal["url", "b64_json"],

@@ -1133,7 +1132,7 @@ class API:
        command = await self._send_image_edits_command(
            image=image,
            prompt=prompt,
            model=ModelId(model),
            model=model,
            n=n,
            size=size,
            response_format=response_format,

@@ -1189,7 +1188,7 @@ class API:
        command = await self._send_image_edits_command(
            image=image,
            prompt=prompt,
            model=ModelId(model),
            model=model,
            n=n,
            size=size,
            response_format=response_format,

@@ -1250,7 +1249,6 @@ class API:
        tg.start_soon(self._apply_state)
        tg.start_soon(self._pause_on_new_election)
        tg.start_soon(self._cleanup_expired_images)
        tg.start_soon(self._run_bandwidth_server)
        print_startup_banner(self.port)
        await serve(
            cast(ASGIFramework, self.app),

@@ -1258,43 +1256,6 @@ class API:
            shutdown_trigger=lambda: anyio.sleep_forever(),
        )
    async def _run_bandwidth_server(self):
        """TCP server for iperf-like bandwidth testing."""
        bandwidth_port = self.port + 1
        listener = await anyio.create_tcp_listener(local_port=bandwidth_port)
        logger.info(f"Bandwidth test server listening on port {bandwidth_port}")
        await listener.serve(self._handle_bandwidth_connection)

    async def _handle_bandwidth_connection(self, stream: SocketStream) -> None:
        """Handle a single bandwidth test connection."""
        try:
            mode = await stream.receive(1)
            if mode == b"U":
                # Upload test: client sends, we receive
                bytes_received = 0
                start = time.perf_counter()
                while True:
                    try:
                        data = await stream.receive(1024 * 1024)
                        if not data or data == b"DONE" or b"DONE" in data:
                            break
                        bytes_received += len(data)
                    except EndOfStream:
                        break
                elapsed = time.perf_counter() - start
                logger.debug(f"Bandwidth upload: {bytes_received} bytes in {elapsed:.3f}s")
            elif mode == b"D":
                # Download test: we send, client receives
                chunk = b"X" * (1024 * 1024)
                start = time.perf_counter()
                while time.perf_counter() - start < 1.0:  # Send for 1s
                    try:
                        await stream.send(chunk)
                    except (BrokenResourceError, ClosedResourceError):
                        break
        except Exception as e:
            logger.debug(f"Bandwidth connection error: {e}")

        self.command_sender.close()
        self.global_event_receiver.close()
@@ -28,9 +28,6 @@ def create_node_network() -> NodeNetworkInfo:
|
||||
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
|
||||
latency_ms=1.0,
|
||||
other_to_sink_bandwidth_mbps=1000.0,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -366,10 +366,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
ip_address="10.0.0.1",
|
||||
)
|
||||
ethernet_conn = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000"),
|
||||
latency_ms=1.0,
|
||||
other_to_sink_bandwidth_mbps=1000.0,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
|
||||
)
|
||||
|
||||
node_network = {
|
||||
|
||||
@@ -15,9 +15,6 @@ def topology() -> Topology:
|
||||
def socket_connection() -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
latency_ms=1.0,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
other_to_sink_bandwidth_mbps=1000.0,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -216,8 +216,6 @@ def get_node_id_keypair(
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
|
||||
from multiprocessing.synchronize import Semaphore as SemaphoreT
|
||||
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture, mark
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
@@ -74,7 +74,6 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
|
||||
os.remove(p)
|
||||
|
||||
|
||||
@mark.skip(reason="this functionality is currently disabled but may return in future")
|
||||
def test_node_id_fetching(caplog: LogCaptureFixture):
|
||||
reps = 10
|
||||
|
||||
|
||||
@@ -16,9 +16,6 @@ def test_state_serialization_roundtrip() -> None:
|
||||
sink=node_b,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||
latency_ms=1.5,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
other_to_sink_bandwidth_mbps=1100.0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
@@ -116,8 +116,8 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_tokens_details: PromptTokensDetails
|
||||
completion_tokens_details: CompletionTokensDetails
|
||||
prompt_tokens_details: PromptTokensDetails | None = None
|
||||
completion_tokens_details: CompletionTokensDetails | None = None
|
||||
|
||||
|
||||
class StreamingChoiceResponse(BaseModel):
|
||||
@@ -170,10 +170,6 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class StreamOptions(BaseModel):
|
||||
include_usage: bool = False
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(TaggedModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@@ -190,7 +186,6 @@ class ChatCompletionTaskParams(TaggedModel):
|
||||
seed: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool = False
|
||||
stream_options: StreamOptions | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -17,7 +17,6 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -29,7 +28,6 @@ class ErrorChunk(BaseChunk):
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -24,22 +24,10 @@ class RDMAConnection(FrozenModel):
|
||||
|
||||
class SocketConnection(FrozenModel):
|
||||
sink_multiaddr: Multiaddr
|
||||
latency_ms: float
|
||||
sink_to_other_bandwidth_mbps: float
|
||||
other_to_sink_bandwidth_mbps: float
|
||||
|
||||
@property
|
||||
def bandwidth_mbps(self):
|
||||
return min(self.sink_to_other_bandwidth_mbps, self.other_to_sink_bandwidth_mbps)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.sink_multiaddr.ip_address)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SocketConnection):
|
||||
return NotImplemented
|
||||
return self.sink_multiaddr == other.sink_multiaddr
|
||||
|
||||
|
||||
class Connection(FrozenModel):
|
||||
source: NodeId
|
||||
|
||||
@@ -6,7 +6,6 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -25,7 +24,6 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
@@ -59,7 +57,6 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -1,125 +0,0 @@
import time

import anyio
import httpx
from anyio.abc import SocketStream
from exo.shared.logging import logger
from pydantic.v1 import BaseModel

LATENCY_PING_COUNT = 5
BANDWIDTH_TEST_DURATION_S = 0.5
BANDWIDTH_TEST_PORT_OFFSET = 1  # API port + 1


class ConnectionProfile(BaseModel):
    latency_ms: float
    upload_mbps: float
    download_mbps: float


async def measure_latency(target_ip: str, port: int = 52415) -> float:
    if ":" in target_ip:
        url = f"http://[{target_ip}]:{port}/node_id"
    else:
        url = f"http://{target_ip}:{port}/node_id"

    rtts: list[float] = []

    async with httpx.AsyncClient(timeout=10.0) as client:
        for _ in range(LATENCY_PING_COUNT):
            try:
                start = time.perf_counter()
                response = await client.get(url)
                end = time.perf_counter()

                if response.status_code == 200:
                    rtts.append((end - start) * 1000)
            except (httpx.TimeoutException, httpx.NetworkError, httpx.RemoteProtocolError) as e:
                logger.debug(f"Latency ping failed: {e}")

    if not rtts:
        raise ConnectionError(f"Failed to measure latency to {target_ip}:{port}")

    return sum(rtts) / len(rtts)


async def _measure_upload_tcp(stream: SocketStream, duration: float) -> float:
    """Send data for duration seconds, return Mbps."""
    chunk = b"X" * (1024 * 1024)  # 1MB
    bytes_sent = 0
    start = time.perf_counter()
    deadline = start + duration

    while time.perf_counter() < deadline:
        await stream.send(chunk)
        bytes_sent += len(chunk)

    elapsed = time.perf_counter() - start
    return (bytes_sent * 8 / elapsed) / 1_000_000 if elapsed > 0 else 0.0


async def _measure_download_tcp(stream: SocketStream, duration: float) -> float:
    """Receive data for duration seconds, return Mbps."""
    bytes_received = 0
    start = time.perf_counter()

    with anyio.move_on_after(duration):
        while True:
            data = await stream.receive(1024 * 1024)
            if not data:
                break
            bytes_received += len(data)

    elapsed = time.perf_counter() - start
    return (bytes_received * 8 / elapsed) / 1_000_000 if elapsed > 0 else 0.0


async def measure_bandwidth_tcp(target_ip: str, port: int) -> tuple[float, float]:
    """Measure bandwidth using raw TCP like iperf."""
    upload_mbps = 0.0
    download_mbps = 0.0

    try:
        async with await anyio.connect_tcp(target_ip, port) as stream:
            # Protocol: send 'U' for upload test, 'D' for download test
            # Upload: client sends, server receives
            await stream.send(b"U")
            upload_mbps = await _measure_upload_tcp(stream, BANDWIDTH_TEST_DURATION_S)
            await stream.send(b"DONE")
            logger.debug(f"Upload: {upload_mbps:.1f} Mbps")
    except Exception as e:
        logger.debug(f"Upload TCP test failed: {e}")

    try:
        async with await anyio.connect_tcp(target_ip, port) as stream:
            # Download: client receives, server sends
            await stream.send(b"D")
            download_mbps = await _measure_download_tcp(stream, BANDWIDTH_TEST_DURATION_S)
            logger.debug(f"Download: {download_mbps:.1f} Mbps")
    except Exception as e:
        logger.debug(f"Download TCP test failed: {e}")

    return upload_mbps, download_mbps


async def profile_connection(target_ip: str, port: int = 52415) -> ConnectionProfile:
    logger.debug(f"Profiling connection to {target_ip}:{port}")

    latency_ms = await measure_latency(target_ip, port)
    logger.debug(f"Measured latency to {target_ip}: {latency_ms:.2f}ms")

    bandwidth_port = port + BANDWIDTH_TEST_PORT_OFFSET
    upload_mbps, download_mbps = await measure_bandwidth_tcp(target_ip, bandwidth_port)
    logger.debug(
        f"Measured bandwidth to {target_ip}: "
        f"upload={upload_mbps:.1f}Mbps, download={download_mbps:.1f}Mbps"
    )

    if upload_mbps == 0.0 and download_mbps == 0.0:
        raise ConnectionError(f"Failed to measure bandwidth to {target_ip}:{bandwidth_port}")

    return ConnectionProfile(
        latency_ms=latency_ms,
        upload_mbps=upload_mbps,
        download_mbps=download_mbps,
    )
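(For reference, a minimal sketch of how this profiler module is driven end to end: latency via HTTP pings to /node_id, then the U/D TCP bandwidth tests against API port + 1. The target address is a placeholder, not taken from the diff.)

    import anyio

    async def _demo() -> None:
        profile = await profile_connection("192.0.2.10", port=52415)
        print(profile.latency_ms, profile.upload_mbps, profile.download_mbps)

    anyio.run(_demo)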
@@ -3,7 +3,6 @@ from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
@@ -13,29 +12,25 @@ from mlx_lm.models.cache import (
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.9
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.85
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(
|
||||
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
|
||||
):
|
||||
def __init__(self, tokenizer: TokenizerWrapper):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
self._group = group
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
@@ -86,13 +81,13 @@ class KVPrefixCache:
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
length = _get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
if length == max_length:
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = cache_length(self.caches[i])
|
||||
cached_length = _cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
@@ -114,7 +109,7 @@ class KVPrefixCache:
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = cache_length(self.caches[best_snapshot_index])
|
||||
cached_length = _cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
@@ -136,37 +131,29 @@ class KVPrefixCache:
        return prompt_cache, tokenized_prompt, None

    def _evict_if_needed(self):
        """Evict least recently used entries while memory usage is high."""
        """Evict least recently used entries while memory pressure is high."""
        if len(self.caches) == 0:
            return

        active: int = mx.metal.get_active_memory()
        limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
        if active < limit * _MEMORY_THRESHOLD:
            return

        # Evict LRU entries until below threshold or only one entry left
        while (
            len(self.caches) > 1
            and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
        ):
        while len(self.caches) > 0:
            lru_index = self._last_used.index(min(self._last_used))
            evicted_tokens = len(self.prompts[lru_index])
            self.prompts.pop(lru_index)
            self.caches.pop(lru_index)
            self._last_used.pop(lru_index)
            logger.info(
                f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
                f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
            )

    def get_memory_used_percentage(self) -> float:
        local_pressure: float = get_memory_used_percentage()

        if self._group is None:
            return local_pressure

        all_pressure = mx.distributed.all_gather(
            mx.array([local_pressure], dtype=mx.float32),
            group=self._group,
        )
        # .item() evals.
        max_pressure = float(mx.max(all_pressure).item())
        return max_pressure
            active = mx.metal.get_active_memory()
            if active < limit * _MEMORY_THRESHOLD:
                break
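(A small self-contained sketch of the LRU bookkeeping this eviction loop relies on, using plain lists and a stand-in for the Metal memory check; purely illustrative, not part of the diff.)

    prompts = ["p0", "p1", "p2"]
    caches = ["c0", "c1", "c2"]
    last_used = [7, 3, 5]  # monotonic access counters; smallest = least recently used

    def over_threshold() -> bool:
        # stand-in for mx.metal.get_active_memory() >= limit * _MEMORY_THRESHOLD
        return len(caches) > 1

    while caches and over_threshold():
        lru = last_used.index(min(last_used))  # entry "p1" is evicted first (counter 3)
        prompts.pop(lru)
        caches.pop(lru)
        last_used.pop(lru)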
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
@@ -181,13 +168,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
return mx.array(tokenized_prompt)
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
def _cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
"""Find the length of the common prefix between two token arrays."""
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
|
||||
if n == 0:
|
||||
@@ -198,17 +185,6 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
return int(mx.sum(prefix_mask).item())
|
||||
|
||||
|
||||
def get_available_memory() -> Memory:
|
||||
mem: int = psutil.virtual_memory().available
|
||||
return Memory.from_bytes(mem)
|
||||
|
||||
|
||||
def get_memory_used_percentage() -> float:
|
||||
mem = psutil.virtual_memory()
|
||||
# percent is 0-100
|
||||
return float(mem.percent / 100)
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
|
||||
@@ -10,11 +10,8 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
CompletionTokensDetails,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
PromptTokensDetails,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
@@ -219,37 +216,22 @@ def mlx_generate(
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
usage: Usage | None = None
|
||||
in_thinking = False
|
||||
reasoning_tokens = 0
|
||||
think_start = tokenizer.think_start
|
||||
think_end = tokenizer.think_end
|
||||
for completion_tokens, out in enumerate(
|
||||
stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=last_token,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
),
|
||||
start=1,
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=last_token,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
|
||||
if think_start is not None and out.text == think_start:
|
||||
in_thinking = True
|
||||
elif think_end is not None and out.text == think_end:
|
||||
in_thinking = False
|
||||
if in_thinking:
|
||||
reasoning_tokens += 1
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
@@ -267,24 +249,11 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=int(out.prompt_tokens) + completion_tokens,
|
||||
prompt_tokens_details=PromptTokensDetails(
|
||||
cached_tokens=prefix_hit_length
|
||||
),
|
||||
completion_tokens_details=CompletionTokensDetails(
|
||||
reasoning_tokens=reasoning_tokens
|
||||
),
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
|
||||
@@ -44,7 +44,6 @@ from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.connection_profiler import profile_connection
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
@@ -66,7 +65,6 @@ class Worker:
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
event_index_counter: Iterator[int],
|
||||
profiling_in_progress: set[NodeId] | None = None,
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
@@ -83,8 +81,6 @@ class Worker:
|
||||
self.state: State = State()
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
self._profiling_in_progress: set[NodeId] = profiling_in_progress if profiling_in_progress is not None else set()
|
||||
self._profiling_lock = anyio.Lock()
|
||||
|
||||
self._nack_cancel_scope: CancelScope | None = None
|
||||
self._nack_attempts: int = 0
|
||||
@@ -286,73 +282,37 @@ class Worker:
|
||||
async def _connection_message_event_writer(self):
|
||||
with self.connection_message_receiver as connection_messages:
|
||||
async for msg in connection_messages:
|
||||
event = await self._convert_connection_message_to_event(msg)
|
||||
if event:
|
||||
await self.event_sender.send(event)
|
||||
await self.event_sender.send(
|
||||
self._convert_connection_message_to_event(msg)
|
||||
)
|
||||
|
||||
async def _convert_connection_message_to_event(self, msg: ConnectionMessage):
|
||||
if msg.connection_type == ConnectionMessageType.Connected:
|
||||
async with self._profiling_lock:
|
||||
if msg.node_id in self._profiling_in_progress:
|
||||
return None
|
||||
self._profiling_in_progress.add(msg.node_id)
|
||||
return await self._profile_and_emit_connection(
|
||||
msg.node_id,
|
||||
msg.remote_ipv4,
|
||||
msg.remote_tcp_port,
|
||||
)
|
||||
elif msg.connection_type == ConnectionMessageType.Disconnected:
|
||||
self._profiling_in_progress.discard(msg.node_id)
|
||||
target_ip = msg.remote_ipv4
|
||||
for connection in self.state.topology.list_connections():
|
||||
if (
|
||||
isinstance(connection.edge, SocketConnection)
|
||||
and connection.edge.sink_multiaddr.ip_address == target_ip
|
||||
):
|
||||
return TopologyEdgeDeleted(conn=connection)
|
||||
|
||||
async def _profile_and_emit_connection(
|
||||
self,
|
||||
sink_node_id: NodeId,
|
||||
remote_ip: str,
|
||||
remote_port: int,
|
||||
):
|
||||
try:
|
||||
profile = await profile_connection(remote_ip)
|
||||
except ConnectionError as e:
|
||||
logger.warning(f"Failed to profile connection to {sink_node_id}: {e}")
|
||||
profile = None
|
||||
|
||||
if profile:
|
||||
latency_ms = profile.latency_ms
|
||||
other_to_sink_mbps = profile.upload_mbps
|
||||
sink_to_other_mbps = profile.download_mbps
|
||||
else:
|
||||
latency_ms = 0.0
|
||||
other_to_sink_mbps = 0.0
|
||||
sink_to_other_mbps = 0.0
|
||||
|
||||
logger.info(
|
||||
f"Connection to {sink_node_id} profiled: "
|
||||
f"latency={latency_ms:.2f}ms, "
|
||||
f"upload={other_to_sink_mbps:.1f}Mbps, "
|
||||
f"download={sink_to_other_mbps:.1f}Mbps"
|
||||
)
|
||||
|
||||
return TopologyEdgeCreated(
|
||||
conn=Connection(
|
||||
source=self.node_id,
|
||||
sink=sink_node_id,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{remote_ip}/tcp/{remote_port}"
|
||||
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
|
||||
match msg.connection_type:
|
||||
case ConnectionMessageType.Connected:
|
||||
return TopologyEdgeCreated(
|
||||
conn=Connection(
|
||||
source=self.node_id,
|
||||
sink=msg.node_id,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
),
|
||||
),
|
||||
latency_ms=latency_ms,
|
||||
sink_to_other_bandwidth_mbps=sink_to_other_mbps,
|
||||
other_to_sink_bandwidth_mbps=other_to_sink_mbps,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
case ConnectionMessageType.Disconnected:
|
||||
return TopologyEdgeDeleted(
|
||||
conn=Connection(
|
||||
source=self.node_id,
|
||||
sink=msg.node_id,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -429,21 +389,21 @@ class Worker:
|
||||
)
|
||||
for nid in conns:
|
||||
for ip in conns[nid]:
|
||||
temp_edge = SocketConnection(
|
||||
edge = SocketConnection(
|
||||
# nonsense multiaddr
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
|
||||
if "." in ip
|
||||
# nonsense multiaddr
|
||||
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
|
||||
latency_ms=0.0,
|
||||
sink_to_other_bandwidth_mbps=0.0,
|
||||
other_to_sink_bandwidth_mbps=0.0,
|
||||
)
|
||||
if temp_edge not in edges:
|
||||
logger.debug(f"ping discovered new connection to {nid} at {ip}")
|
||||
self._tg.start_soon(
|
||||
self._profile_and_emit_connection,
|
||||
nid,
|
||||
ip,
|
||||
52415,
|
||||
if edge not in edges:
|
||||
logger.debug(f"ping discovered {edge=}")
|
||||
await self.event_sender.send(
|
||||
TopologyEdgeCreated(
|
||||
conn=Connection(
|
||||
source=self.node_id, sink=nid, edge=edge
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
for conn in self.state.topology.out_edges(self.node_id):
|
||||
|
||||
@@ -37,7 +37,6 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -112,12 +111,8 @@ def main(
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
seen = set[TaskId]()
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -168,7 +163,7 @@ def main(
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
@@ -282,11 +277,9 @@ def main(
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
completion_tokens += 1
|
||||
if (
|
||||
device_rank == 0
|
||||
and response.finish_reason == "error"
|
||||
@@ -314,7 +307,6 @@ def main(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
usage=response.usage,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
@@ -328,7 +320,6 @@ def main(
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
usage=response.usage,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -544,8 +535,7 @@ def parse_gpt_oss(
|
||||
name=current_tool_name,
|
||||
arguments="".join(tool_arg_parts).strip(),
|
||||
)
|
||||
],
|
||||
usage=response.usage,
|
||||
]
|
||||
)
|
||||
tool_arg_parts = []
|
||||
current_tool_name = recipient
|
||||
@@ -693,7 +683,7 @@ def parse_tool_calls(
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
|
||||
@@ -127,25 +127,20 @@ class RunnerSupervisor:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
f"Skipping invalid task {task} as it has already been submitted"
|
||||
)
|
||||
return
|
||||
if task.task_id in self.completed:
|
||||
logger.warning(
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
return
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
try:
|
||||
await self._task_sender.send_async(task)
|
||||
self._task_sender.send(task)
|
||||
except ClosedResourceError:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
return
|
||||
await event.wait()
|
||||
logger.info(f"Finished task {task}")
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
|
||||
@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
cache_length,
|
||||
_cache_length,
|
||||
_get_prefix_length,
|
||||
encode_prompt,
|
||||
get_prefix_length,
|
||||
make_kv_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
|
||||
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
|
||||
def test_identical_arrays(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 5
|
||||
assert _get_prefix_length(a, b) == 5
|
||||
|
||||
def test_no_common_prefix(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([4, 5, 6])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_partial_prefix(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 7, 8])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_prompt_longer_than_cached(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_cached_longer_than_prompt(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_single_token_match(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 5, 6])
|
||||
assert get_prefix_length(a, b) == 1
|
||||
assert _get_prefix_length(a, b) == 1
|
||||
|
||||
def test_empty_prompt(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_empty_cached(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_both_empty(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
|
||||
class TestKVPrefix:
|
||||
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert cache_length(cache) == len(tokens)
|
||||
assert _cache_length(cache) == len(tokens)
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
assert stored_length > 0
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
|
||||
long_tokens = encode_prompt(tokenizer, long_prompt)
|
||||
|
||||
# The prompts share a prefix (chat template preamble + "Hi")
|
||||
expected_prefix = get_prefix_length(long_tokens, short_tokens)
|
||||
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
|
||||
assert expected_prefix > 0, (
|
||||
"Prompts should share a prefix from the chat template"
|
||||
)
|
||||
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
# Stored cache must be unchanged
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
|
||||
def test_stored_cache_survives_repeated_get_mutate_cycles(
|
||||
self, model_and_tokenizer
|
||||
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
|
||||
layer_cache.update_and_fetch(extra, extra)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
f"Failed on loop {i}"
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
|
||||
assert len(kv_prefix_cache.caches) == 1
|
||||
# Cache should contain prompt + generated tokens
|
||||
expected_length = len(prompt_tokens) + generated_tokens
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
|
||||
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
first_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation: same long prompt + extra content (simulating multi-turn)
|
||||
task2 = ChatCompletionTaskParams(
|
||||
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt2_tokens = encode_prompt(tokenizer, prompt2)
|
||||
|
||||
# Verify the prompts share a long prefix
|
||||
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
|
||||
|
||||
# Second generation should reuse the cached prefix (only prefill new tokens)
|
||||
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
|
||||
# With prefix_hit > 1000, should update in-place (not add a second entry)
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
|
||||
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
assert updated_cache_length > first_cache_length
|
||||
|
||||
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
|
||||
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
|
||||
):
|
||||
pass
|
||||
|
||||
firstcache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation gets the cache and mutates it during generation
|
||||
for _response in mlx_generate(
|
||||
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
|
||||
pass
|
||||
|
||||
# The first stored cache must not have been mutated by the second generation
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
|
||||
|
||||
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
|
||||
@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
|
||||
|
||||
@pytest.fixture
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
# initialize_mlx returns a "group" equal to 1
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
|
||||
|
||||
def fake_generate(*_1: object, **_2: object):
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
|
||||
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
@@ -147,14 +147,6 @@ class MockTokenizer:
|
||||
has_tool_calling = False
|
||||
|
||||
|
||||
class MockGroup:
|
||||
def rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def size(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -190,8 +182,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
text="hi",
|
||||
token_id=0,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
stats=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
import multiprocessing as mp
|
||||
import socket
|
||||
import time
|
||||
import typing
|
||||
from typing import Literal
|
||||
|
||||
import anyio
|
||||
from loguru import logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from hypercorn import Config
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.shared.logging import InterceptLogger, logger_setup
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.commands import CommandId
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -36,9 +31,14 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerShutdown,
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
|
||||
from exo.utils.channels import channel, mp_channel
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
@@ -46,37 +46,36 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
rdma_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "rdma", "both"]


mp.set_start_method("spawn", force=True)
logger_setup(None)
iid = InstanceId("im testing here")


async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:52414"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
ev = anyio.Event()
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: shutdown.wait(),
shutdown_trigger = lambda: ev.wait()
)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()

def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)

async def tb_detection():
send, recv = channel[GatheredInfo]()
@@ -87,29 +86,19 @@ async def tb_detection():
return recv.collect()


async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))


async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
async def run_test(test: Tests):
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -117,31 +106,67 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)

async def run():
logger.info(f"testing {test.model_id}")

instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["rdma", "both"]:
i = jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)

for instance in instances:
recv = await execute_test(test, instance, hn)

str_out = ""

for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text

if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"

return StreamingResponse(run())


def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
else:
raise ValueError(f"{hn} not in {test.devs}")

card = MODEL_CARDS[test.model_id]
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
ephemeral_port=52417,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,119 +188,86 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance


async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))

match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
max_tokens=50,
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()

send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
for command in commands:
task_send.send(command)

async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
entrypoint(
bound_instance,
ev_send,
task_recv,
logger, # type: ignore
)

ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
# TODO(evan): return ev_recv.collect()
return []

async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)


def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
world_size = len(test.devs)
assert test.rdma_devs

return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
jaccl_devices=test.rdma_devs,
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
RunnerId(host[0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
for i, host in enumerate(test.devs)
},
),
)

def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()

runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)


if __name__ == "__main__":
anyio.run(main)

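For reference, a request against the new /run_test endpoint might look like the sketch below; the host names, IP addresses, and choice of model are placeholders, but the fields mirror the Tests model above and 52414 is the bind port set in main():

curl -sN -X POST "http://198.51.100.10:52414/run_test" \
  -H "Content-Type: application/json" \
  -d '{"devs": [["mac-a", "198.51.100.10"], ["mac-b", "198.51.100.11"]], "rdma_devs": null, "model_id": "qwen3-30b", "kind": "ring"}'

Per the handler above, the response streams the generated text followed by the JSON of the terminal runner-status event for each instance that was exercised.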
50
tests/run_distributed_test.sh
Executable file
@@ -0,0 +1,50 @@
#!/usr/bin/env bash
set -euo pipefail

[ $# -eq 0 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}

[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
for host; do
curl -m 1 -X POST "http://$host:52414/kill" >/dev/null 2>&1 || true &
done
wait

echo "Deploying $commit to $# hosts..."

pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0

for host; do
colour=${colours[i++ % 4]}
ssh -tt -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q $commit
nix develop -c uv sync
.venv/bin/python tests/headless_runner.py
'" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
pids+=" $!"
done

for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52414/models"; do sleep 1; done
done

uv run tests/start_distributed_test.py "$@"
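A typical invocation, with placeholder host names, would be:

./tests/run_distributed_test.sh mac-a mac-b

The script expects a clean working tree whose HEAD is already pushed to origin and SSH access to each host as "$host@$host"; it waits until every host's headless runner answers on port 52414 before handing the host list to start_distributed_test.py.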
85
tests/start_distributed_test.py
Executable file
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen

if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be rdma or ring"
)

kind = args[0] if args[0] in ("rdma", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)


def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]


def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:52414/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]


def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:52414/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)


with ThreadPoolExecutor(n) as exctr:
if kind in ("rdma", "both"):
payloads = list(exctr.map(get_tb, ips))

u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
rdma_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
rdma_devs = None

models = set[str].intersection(*exctr.map(get_models, ips))

print("\n")
print("=" * 70)
print(f"Starting test with {models}")
print("=" * 70)
print("\n")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "rdma_devs": rdma_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))
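The driver can also be run on its own; per the argument handling above, an optional first argument selects the transport and the remaining arguments are Tailscale host names (placeholders here):

uv run tests/start_distributed_test.py ring mac-a mac-b
uv run tests/start_distributed_test.py mac-a mac-b

With no kind argument it defaults to "both". Host names are resolved to IPs via tailscale status, and one test is posted per model that every host reports under /models.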
@@ -1,54 +0,0 @@
#!/usr/bin/env bash

set -euo pipefail

query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}

if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi

kind=$1
shift

test_kinds="ring jaccl"

if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi

hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done

devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"

model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")

for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done
39
tmp/run_exo_on.sh
Executable file
@@ -0,0 +1,39 @@
#!/usr/bin/env bash
set -euo pipefail

[ $# -eq 0 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}

[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
echo "Deploying $commit to $# hosts..."

pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0

for host; do
colour=${colours[i++ % 4]}
ssh -t -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q $commit
nix develop -c uv sync
uv run exo
'" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
pids+=" $!"
done
wait
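Usage mirrors the test runner script above; a minimal sketch with placeholder host names:

./tmp/run_exo_on.sh mac-a mac-b

The difference is that each host runs uv run exo instead of the headless test runner, so this deploys the current pushed commit and starts a regular exo node on every host.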