Compare commits


1 Commit

Author SHA1 Message Date
Evan  c3753f8631  yay  2026-01-30 11:53:18 +00:00
17 changed files with 414 additions and 385 deletions

View File

@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.load(model_id)
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -166,8 +166,9 @@ class ResumableShardDownloader(ShardDownloader):
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.warning(f"Error downloading shard: {type(e).__name__}")
logger.error("Error downloading shard:", e)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -65,9 +65,7 @@ from exo.shared.types.api import (
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
StreamOptions,
ToolCall,
Usage,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -115,9 +113,7 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk,
command_id: CommandId,
usage: Usage | None,
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -142,10 +138,21 @@ def chunk_to_response(
finish_reason=chunk.finish_reason,
)
],
usage=usage,
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -267,7 +274,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await ModelCard.load(payload.model_id),
model_card=await resolve_model_card(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -284,7 +291,7 @@ class API:
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await ModelCard.load(instance.shard_assignments.model_id)
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
@@ -312,7 +319,7 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await ModelCard.load(model_id)
model_card = await resolve_model_card(model_id)
try:
placements = get_instance_placements(
@@ -515,10 +522,9 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, stream_options: StreamOptions | None = None
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
include_usage = stream_options.include_usage if stream_options else False
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
@@ -534,10 +540,8 @@ class API:
yield "data: [DONE]\n\n"
return
usage = chunk.usage if include_usage else None
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id, usage=usage
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
@@ -553,9 +557,8 @@ class API:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: ModelId | None = None
model: str | None = None
finish_reason: FinishReason | None = None
usage: Usage | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ErrorChunk):
@@ -580,9 +583,6 @@ class API:
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.usage is not None:
usage = chunk.usage
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -604,7 +604,6 @@ class API:
finish_reason=finish_reason,
)
],
usage=usage,
)
async def _collect_chat_completion_with_stats(
@@ -612,7 +611,7 @@ class API:
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: ModelId | None = None
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
@@ -665,7 +664,7 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
@@ -674,7 +673,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await ModelCard.load(ModelId(payload.model))
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -692,7 +691,7 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, payload.stream_options),
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
@@ -701,7 +700,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await ModelCard.load(ModelId(payload.model))
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -721,12 +720,12 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: ModelId) -> ModelId:
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await ModelCard.load(model)
model_card = await resolve_model_card(ModelId(model))
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
@@ -772,7 +771,7 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.model = await self._validate_image_model(payload.model)
command = ImageGeneration(
request_params=payload,
@@ -1017,7 +1016,7 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.model = await self._validate_image_model(payload.model)
payload.stream = False
payload.partial_images = 0
@@ -1038,7 +1037,7 @@ class API:
self,
image: UploadFile,
prompt: str,
model: ModelId,
model: str,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1133,7 +1132,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=ModelId(model),
model=model,
n=n,
size=size,
response_format=response_format,
@@ -1189,7 +1188,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=ModelId(model),
model=model,
n=n,
size=size,
response_format=response_format,
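
The api.py changes above swap direct ModelCard.load calls for the new resolve_model_card helper. A rough standalone sketch of its lookup order, using stand-in types and made-up ids rather than the real MODEL_CARDS registry:

import asyncio
from dataclasses import dataclass


@dataclass
class Card:
    model_id: str


# stand-in registry: short name -> card carrying the full model id
CARDS: dict[str, Card] = {"qwen3-0.6b": Card(model_id="org/qwen3-0.6b-4bit")}


async def fetch_from_hf(model_id: str) -> Card:
    # stand-in for ModelCard.from_hf: build a card for an unknown model
    return Card(model_id=model_id)


async def resolve(model_id: str) -> Card:
    if model_id in CARDS:                      # 1) registry short name
        return CARDS[model_id]
    for card in CARDS.values():                # 2) full id of a known card
        if card.model_id == model_id:
            return card
    return await fetch_from_hf(model_id)       # 3) unknown model: fetch its card


print(asyncio.run(resolve("qwen3-0.6b")).model_id)           # registry hit
print(asyncio.run(resolve("org/qwen3-0.6b-4bit")).model_id)  # scan hit
print(asyncio.run(resolve("org/unknown-model")).model_id)    # fallback fetch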

View File

@@ -216,8 +216,6 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from loguru import logger
from pytest import LogCaptureFixture, mark
from pytest import LogCaptureFixture
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,7 +74,6 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
os.remove(p)
@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10

View File

@@ -3,7 +3,7 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -116,8 +116,8 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: PromptTokensDetails
completion_tokens_details: CompletionTokensDetails
prompt_tokens_details: PromptTokensDetails | None = None
completion_tokens_details: CompletionTokensDetails | None = None
class StreamingChoiceResponse(BaseModel):
@@ -170,10 +170,6 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class StreamOptions(BaseModel):
include_usage: bool = False
class ChatCompletionTaskParams(TaggedModel):
model_config = ConfigDict(extra="ignore")
@@ -190,7 +186,6 @@ class ChatCompletionTaskParams(TaggedModel):
seed: int | None = None
stop: str | list[str] | None = None
stream: bool = False
stream_options: StreamOptions | None = None
temperature: float | None = None
top_p: float | None = None
tools: list[dict[str, Any]] | None = None

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -17,7 +17,6 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
usage: Usage | None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
@@ -29,7 +28,6 @@ class ErrorChunk(BaseChunk):
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
usage: Usage | None
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None

View File

@@ -6,7 +6,6 @@ from exo.shared.types.api import (
GenerationStats,
ImageGenerationStats,
ToolCallItem,
Usage,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -25,7 +24,6 @@ class GenerationResponse(BaseRunnerResponse):
# logprobs: list[float] | None = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
usage: Usage | None
class ImageGenerationResponse(BaseRunnerResponse):
@@ -59,7 +57,6 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -3,7 +3,6 @@ from copy import deepcopy
from typing import Any, cast
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
@@ -13,29 +12,25 @@ from mlx_lm.models.cache import (
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9
_DEFAULT_MEMORY_THRESHOLD = 0.85
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
class KVPrefixCache:
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
):
def __init__(self, tokenizer: TokenizerWrapper):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
@@ -86,13 +81,13 @@ class KVPrefixCache:
best_snapshot_index, best_snapshot_length = None, 0
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(tokenized_prompt, cached_prompt)
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = cache_length(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -114,7 +109,7 @@ class KVPrefixCache:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = cache_length(self.caches[best_snapshot_index])
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -136,37 +131,29 @@ class KVPrefixCache:
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory usage is high."""
"""Evict least recently used entries while memory pressure is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_THRESHOLD:
return
# Evict LRU entries until below threshold or only one entry left
while (
len(self.caches) > 1
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
while len(self.caches) > 0:
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
)
def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_used_percentage()
if self._group is None:
return local_pressure
all_pressure = mx.distributed.all_gather(
mx.array([local_pressure], dtype=mx.float32),
group=self._group,
)
# .item() evals.
max_pressure = float(mx.max(all_pressure).item())
return max_pressure
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_THRESHOLD:
break
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
@@ -181,13 +168,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(tokenized_prompt)
def cache_length(cache: KVCacheType) -> int:
def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
@@ -198,17 +185,6 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
return int(mx.sum(prefix_mask).item())
def get_available_memory() -> Memory:
mem: int = psutil.virtual_memory().available
return Memory.from_bytes(mem)
def get_memory_used_percentage() -> float:
mem = psutil.virtual_memory()
# percent is 0-100
return float(mem.percent / 100)
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:
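
The _evict_if_needed rewrite above drops the distributed pressure gather and instead re-checks Metal memory after every eviction, always removing the least recently used entry first. A toy, dependency-free model of that loop (memory_used() and its readings are invented for illustration; the real code uses mx.metal.get_active_memory()):

THRESHOLD = 0.85

prompts = ["p0", "p1", "p2"]
caches = ["c0", "c1", "c2"]
last_used = [3, 1, 2]                 # p1 has the smallest counter: least recently used

readings = iter([0.95, 0.91, 0.80])   # fake memory fraction before/after each eviction


def memory_used() -> float:
    return next(readings)


def evict_if_needed() -> None:
    if not caches or memory_used() < THRESHOLD:
        return
    while caches:
        lru = last_used.index(min(last_used))   # index of least recently used entry
        prompts.pop(lru)
        caches.pop(lru)
        last_used.pop(lru)
        if memory_used() < THRESHOLD:           # re-check pressure after each drop
            break


evict_if_needed()
print(prompts)  # ['p0'] - p1 then p2 were evicted before pressure fell below 0.85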

View File

@@ -10,11 +10,8 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
CompletionTokensDetails,
FinishReason,
GenerationStats,
PromptTokensDetails,
Usage,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
@@ -219,37 +216,22 @@ def mlx_generate(
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
usage: Usage | None = None
in_thinking = False
reasoning_tokens = 0
think_start = tokenizer.think_start
think_end = tokenizer.think_end
for completion_tokens, out in enumerate(
stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
),
start=1,
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)
if think_start is not None and out.text == think_start:
in_thinking = True
elif think_end is not None and out.text == think_end:
in_thinking = False
if in_thinking:
reasoning_tokens += 1
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
@@ -267,24 +249,11 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=reasoning_tokens
),
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
usage=usage,
)
if out.finish_reason is not None:

View File

@@ -163,7 +163,7 @@ def main(
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer, group)
kv_prefix_cache = KVPrefixCache(tokenizer)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
@@ -277,11 +277,9 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
completion_tokens = 0
for response in mlx_generator:
match response:
case GenerationResponse():
completion_tokens += 1
if (
device_rank == 0
and response.finish_reason == "error"
@@ -309,7 +307,6 @@ def main(
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
usage=response.usage,
finish_reason=response.finish_reason,
stats=response.stats,
),
@@ -323,7 +320,6 @@ def main(
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
),
)
)
@@ -539,8 +535,7 @@ def parse_gpt_oss(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
],
usage=response.usage,
]
)
tool_arg_parts = []
current_tool_name = recipient
@@ -688,7 +683,7 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
yield ToolCallResponse(tool_calls=tools)
except (
json.JSONDecodeError,

View File

@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
cache_length,
_cache_length,
_get_prefix_length,
encode_prompt,
get_prefix_length,
make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 5
assert _get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert get_prefix_length(a, b) == 3
assert _get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert get_prefix_length(a, b) == 3
assert _get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 3
assert _get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert get_prefix_length(a, b) == 1
assert _get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
class TestKVPrefix:
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert cache_length(cache) == len(tokens)
assert _cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache.add_kv_cache(prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = cache_length(kv_prefix_cache.caches[0])
stored_length = _cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = get_prefix_length(long_tokens, short_tokens)
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = cache_length(kv_prefix_cache.caches[0])
stored_length = _cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = cache_length(kv_prefix_cache.caches[0])
stored_length = _cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = cache_length(kv_prefix_cache.caches[0])
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
):
pass
firstcache_length = cache_length(kv_prefix_cache.caches[0])
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
pass
# The first stored cache must not have been mutated by the second generation
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)

View File

@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
@@ -147,14 +147,6 @@ class MockTokenizer:
has_tool_calling = False
class MockGroup:
def rank(self) -> int:
return 0
def size(self) -> int:
return 1
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -190,8 +182,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
text="hi",
token_id=0,
finish_reason="stop",
usage=None,
stats=None,
),
)

View File

@@ -1,26 +1,21 @@
import multiprocessing as mp
import socket
import time
import typing
from typing import Literal
import anyio
from loguru import logger
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, Response
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -36,9 +31,14 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.channels import channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
@@ -46,37 +46,36 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
rdma_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "rdma", "both"]
mp.set_start_method("spawn", force=True)
logger_setup(None)
iid = InstanceId("im testing here")
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:52414"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
ev = anyio.Event()
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: shutdown.wait(),
shutdown_trigger = lambda: ev.wait()
)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)
async def tb_detection():
send, recv = channel[GatheredInfo]()
@@ -87,29 +86,19 @@ async def tb_detection():
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
async def run_test(test: Tests):
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -117,31 +106,67 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["rdma", "both"]:
i = jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
ephemeral_port=52417,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,119 +188,86 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
max_tokens=50,
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
for command in commands:
task_send.send(command)
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
entrypoint(
bound_instance,
ev_send,
task_recv,
logger, # type: ignore
)
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
# TODO(evan): return ev_recv.collect()
return []
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
world_size = len(test.devs)
assert test.rdma_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
jaccl_devices=test.rdma_devs,
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
RunnerId(host[0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
for i, host in enumerate(test.devs)
},
),
)
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)
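
One detail worth calling out from ring_instance above: neighbour hosts are now assigned with modular indexing, so the first and last devices wrap around to each other instead of being left unset at the ends. A tiny illustration with made-up hostnames:

devs = ["mac-a", "mac-b", "mac-c"]
world_size = len(devs)
for i, dev in enumerate(devs):
    prev_dev = devs[(i - 1) % world_size]
    next_dev = devs[(i + 1) % world_size]
    print(f"{dev}: prev={prev_dev} next={next_dev}")
# mac-a: prev=mac-c next=mac-b   (wraps around)
# mac-b: prev=mac-a next=mac-c
# mac-c: prev=mac-b next=mac-a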

tests/run_distributed_test.sh (new executable file, 50 lines)
View File

@@ -0,0 +1,50 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -eq 0 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
for host; do
curl -m 1 -X POST "http://$host:52414/kill" >/dev/null 2>&1 || true &
done
wait
echo "Deploying $commit to $# hosts..."
pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
ssh -tt -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q $commit
nix develop -c uv sync
.venv/bin/python tests/headless_runner.py
'" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
pids+=" $!"
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52414/models"; do sleep 1; done
done
uv run tests/start_distributed_test.py "$@"

tests/start_distributed_test.py (new executable file, 85 lines)
View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen
if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be rdma or ring"
)
kind = args[0] if args[0] in ("rdma", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)
def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]
def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:52414/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]
def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:52414/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)
with ThreadPoolExecutor(n) as exctr:
if kind in ("rdma", "both"):
payloads = list(exctr.map(get_tb, ips))
u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
rdma_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
rdma_devs = None
models = set[str].intersection(*exctr.map(get_models, ips))
print("\n")
print("=" * 70)
print(f"Starting test with {models}")
print("=" * 70)
print("\n")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "rdma_devs": rdma_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))
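
To make the Thunderbolt mapping above concrete, here is a worked toy run for two hosts. The payload shape (domainUuid, rdmaInterface, sourceUuid, sinkUuid) follows the keys the script reads; the UUIDs and interface names are invented:

payloads = [
    # host 0: its own identifier plus the connection it sees
    [{"MacThunderboltIdentifiers": {"idents": [{"domainUuid": "A", "rdmaInterface": "rdma_en3"}]},
      "MacThunderboltConnections": {"conns": [{"sourceUuid": "A", "sinkUuid": "B"}]}}],
    # host 1
    [{"MacThunderboltIdentifiers": {"idents": [{"domainUuid": "B", "rdmaInterface": "rdma_en4"}]},
      "MacThunderboltConnections": {"conns": [{"sourceUuid": "B", "sinkUuid": "A"}]}}],
]
n = len(payloads)
u2e = {ident["domainUuid"]: (i, ident["rdmaInterface"])
       for i, p in enumerate(payloads) for d in p
       for ident in d.get("MacThunderboltIdentifiers", {}).get("idents", [])}
edges = {(u2e[s][0], u2e[t][0]): u2e[t][1]
         for p in payloads for d in p
         for c in d.get("MacThunderboltConnections", {}).get("conns", [])
         if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e}
rdma_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
print(rdma_devs)  # [[None, 'rdma_en4'], ['rdma_en3', None]]

Row i then holds, for every peer j, the RDMA interface that host j reported for its link with host i; missing links stay None. This is the matrix that headless_runner.py consumes as jaccl_devices.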

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done

tmp/run_exo_on.sh (new executable file, 39 lines)
View File

@@ -0,0 +1,39 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -eq 0 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
echo "Deploying $commit to $# hosts..."
pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
ssh -t -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q $commit
nix develop -c uv sync
uv run exo
'" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
pids+=" $!"
done
wait