Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
rltakashige
2025-11-20 20:03:51 +00:00
committed by GitHub
parent d793f5f96c
commit 28a91787e8
30 changed files with 645 additions and 332 deletions

View File

@@ -26,7 +26,7 @@ model_ids:
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Pipeline"
sharding: "Tensor"
# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"
@@ -48,6 +48,11 @@ stages:
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g64"
# prompt_length: 64
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g512"
# prompt_length: 64
# generation_length: 512
@@ -58,6 +63,11 @@ stages:
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
- name: "pp256_g64"
prompt_length: 256
generation_length: 64
time_between_requests: 2.0
iterations: 5
# - name: "pp256_g512"
# prompt_length: 256
# generation_length: 512
@@ -83,26 +93,26 @@ stages:
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
- name: "pp4096_g64"
prompt_length: 4096
generation_length: 64
time_between_requests: 2.0
iterations: 4
# - name: "pp4096_g64"
# prompt_length: 4096
# generation_length: 64
# time_between_requests: 2.0
# iterations: 4
# - name: "pp4096_g512"
# prompt_length: 4096
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
- name: "pp8192_g64"
prompt_length: 8192
generation_length: 64
time_between_requests: 2.0
iterations: 4
# - name: "pp8192_g64"
# prompt_length: 8192
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp8192_g512"
# prompt_length: 8192
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# iterations: 5
# - name: "pp16384_g64"
# prompt_length: 16384
# generation_length: 64

View File

@@ -36,9 +36,7 @@ def save_prompt_cache(
state.
"""
def load_prompt_cache(
file_name, return_metadata=...
): # -> tuple[list[Any], Any] | list[Any]:
def load_prompt_cache(file_name: str, return_metadata=...) -> array:
"""
Load a prompt cache from a file.
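A hedged usage sketch of these helpers (standard mlx_lm API, not introduced by this commit; the model id is reused from the benchmark config above):

from mlx_lm import load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache

# Build a per-layer prompt cache, persist it, then restore it together with its metadata.
model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
cache = make_prompt_cache(model)
save_prompt_cache("prefix_cache.safetensors", cache, {"note": "example"})
restored, metadata = load_prompt_cache("prefix_cache.safetensors", return_metadata=True)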

View File

@@ -31,10 +31,10 @@
max-width: 1200px;
margin-bottom: 15px;
margin-top: 20px;
text-align: left;
display: flex;
justify-content: space-between;
display: grid;
grid-template-columns: 1fr auto 1fr;
align-items: flex-end;
gap: 20px;
}
.dashboard-header h1 {
@@ -67,6 +67,18 @@
flex-direction: column;
}
.header-center {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.header-right {
display: flex;
justify-content: flex-end;
}
.header-instances-button {
background-color: transparent;
border: 1px solid var(--exo-medium-gray);
@@ -972,11 +984,11 @@
<label class="launch-label">Sharding:</label>
<div class="strategy-options">
<div class="strategy-option">
<input type="radio" id="shardingPipeline" name="sharding" value="Pipeline" checked>
<input type="radio" id="shardingPipeline" name="sharding" value="Pipeline">
<label for="shardingPipeline">Pipeline</label>
</div>
<div class="strategy-option">
<input type="radio" id="shardingTensor" name="sharding" value="Tensor">
<input type="radio" id="shardingTensor" name="sharding" value="Tensor" checked>
<label for="shardingTensor">Tensor</label>
</div>
</div>
@@ -986,16 +998,26 @@
<label class="launch-label">Instance Type:</label>
<div class="strategy-options">
<div class="strategy-option">
<input type="radio" id="instanceMlxRing" name="instance_meta" value="MlxRing" checked>
<input type="radio" id="instanceMlxRing" name="instance_meta" value="MlxRing">
<label for="instanceMlxRing">MLX Ring</label>
</div>
<div class="strategy-option">
<input type="radio" id="instanceMlxIbv" name="instance_meta" value="MlxIbv">
<input type="radio" id="instanceMlxIbv" name="instance_meta" value="MlxIbv" checked>
<label for="instanceMlxIbv">MLX IBV</label>
</div>
</div>
</div>
<div class="strategy-selector">
<label class="launch-label">Minimum Nodes:</label>
<div class="strategy-options" id="minNodesOptions">
<div class="strategy-option">
<input type="radio" id="minNodes1" name="min_nodes" value="1" checked>
<label for="minNodes1">1</label>
</div>
</div>
</div>
<button id="launchInstanceButton" class="launch-button" disabled>Launch Instance</button>
<div id="launchStatus" class="launch-status"></div>
</div>
@@ -1004,10 +1026,15 @@
<div class="dashboard-header">
<div class="header-left">
<!-- Left section: empty or can be used for future content -->
</div>
<div class="header-center">
<h1><img src="exo-logo.png" alt="EXO logo" height="48" /></h1>
<p class="last-updated" id="lastUpdated">Fetching data...</p>
</div>
<button class="header-instances-button" id="instancesMenuButton">Instances</button>
<div class="header-right">
<button class="header-instances-button" id="instancesMenuButton">Instances</button>
</div>
</div>
<!-- Replaced node-grid with SVG container for topology graph -->
@@ -1068,12 +1095,14 @@
const modelSelect = document.getElementById('modelSelect');
const launchInstanceButton = document.getElementById('launchInstanceButton');
const launchStatus = document.getElementById('launchStatus');
const minNodesOptions = document.getElementById('minNodesOptions');
const USE_MOCK_DATA = false; // <<< FLAG TO TOGGLE MOCK DATA
let currentlySelectedNodeId = null; // To store the ID of the currently selected node
let nodeIdToFriendlyName = {}; // Map nodeId -> friendly name for download sections
let instanceIdToColor = {}; // Map instanceId -> color for visual coding
let connectionToInstances = {}; // Map "nodeA|nodeB" -> [instanceIds] using that connection
let currentNodeCount = 1; // Track the current number of nodes in topology
const API_ENDPOINT = window.location.origin + window.location.pathname.replace(/\/$/, "") + '/state';
const REFRESH_INTERVAL = 1000; // 1 second
@@ -1218,6 +1247,16 @@
instancesMenuButton.classList.toggle('active', sidebarOpen);
}
// Edge IP display flag (can be toggled from console)
window.exoShowEdgeIPs = false;
// Helper function to toggle IP display (accessible from console)
window.toggleEdgeIPs = function() {
window.exoShowEdgeIPs = !window.exoShowEdgeIPs;
console.log(`Edge IP display ${window.exoShowEdgeIPs ? 'enabled' : 'disabled'}`);
return window.exoShowEdgeIPs;
};
// Fetch available models and populate dropdown
async function fetchAndPopulateModels() {
try {
@@ -1281,8 +1320,11 @@
const selectedSharding = document.querySelector('input[name="sharding"]:checked').value;
const selectedInstanceMeta = document.querySelector('input[name="instance_meta"]:checked').value;
const minNodesRadio = document.querySelector('input[name="min_nodes"]:checked');
const minNodes = minNodesRadio ? parseInt(minNodesRadio.value, 10) : 1;
console.log("selectedSharding", selectedSharding);
console.log("selectedInstanceMeta", selectedInstanceMeta);
console.log("minNodes", minNodes);
try {
showLaunchStatus('Launching instance...', 'loading');
@@ -1296,7 +1338,8 @@
body: JSON.stringify({
model_id: selectedModelId,
sharding: selectedSharding,
instance_meta: selectedInstanceMeta
instance_meta: selectedInstanceMeta,
min_nodes: minNodes
})
});
@@ -1858,6 +1901,39 @@
const edgesData = (topologyData && Array.isArray(topologyData.edges)) ? topologyData.edges : [];
const nodeIds = Object.keys(nodesData);
// Update min nodes radio buttons based on current topology
currentNodeCount = Math.max(1, nodeIds.length);
if (minNodesOptions) {
// Get currently selected value before regenerating
const currentlySelected = document.querySelector('input[name="min_nodes"]:checked');
const hasOnlyDefaultOption = minNodesOptions.children.length === 1;
// Default to maximum nodes on initial load, otherwise preserve user selection
const selectedValue = (currentlySelected && !hasOnlyDefaultOption) ? parseInt(currentlySelected.value, 10) : currentNodeCount;
// Clear and regenerate radio buttons
minNodesOptions.innerHTML = '';
for (let i = 1; i <= currentNodeCount; i++) {
const optionDiv = document.createElement('div');
optionDiv.className = 'strategy-option';
const radio = document.createElement('input');
radio.type = 'radio';
radio.id = `minNodes${i}`;
radio.name = 'min_nodes';
radio.value = i.toString();
// Check if this should be selected (preserve selection or default to maximum)
radio.checked = (i === Math.min(selectedValue, currentNodeCount));
const label = document.createElement('label');
label.htmlFor = `minNodes${i}`;
label.textContent = i.toString();
optionDiv.appendChild(radio);
optionDiv.appendChild(label);
minNodesOptions.appendChild(optionDiv);
}
}
if (nodeIds.length === 0) {
const textEl = document.createElementNS('http://www.w3.org/2000/svg', 'text');
textEl.setAttribute('x', '50%');
@@ -2002,7 +2078,7 @@
arrowsGroup.appendChild(arrowSeg);
// Add label for A->B direction (show all connections)
if (entry.aToBEdges && entry.aToBEdges.length > 0) {
if (window.exoShowEdgeIPs && entry.aToBEdges && entry.aToBEdges.length > 0) {
// Count occurrences of each IP/interface combination
const connectionCounts = new Map();
@@ -2067,7 +2143,7 @@
arrowsGroup.appendChild(arrowSeg);
// Add label for B->A direction (show all connections)
if (entry.bToAEdges && entry.bToAEdges.length > 0) {
if (window.exoShowEdgeIPs && entry.bToAEdges && entry.bToAEdges.length > 0) {
// Count occurrences of each IP/interface combination
const connectionCounts = new Map();

View File

@@ -3,7 +3,10 @@ from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast, override
from mlx_lm.models.cache import KVCache, RotatingKVCache
from mlx_lm.models.cache import (
KVCache,
_BaseCache, # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.llama import Model as LlamaModel
@@ -91,7 +94,7 @@ class PipelineLastLayer(CustomMlxLayer):
x, *args, **kwargs
).arguments.get("cache", None)
assert cache is None or isinstance(cache, (KVCache, RotatingKVCache))
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
output: mx.array = self.original_layer(x, *args, **kwargs)
@@ -99,11 +102,7 @@ class PipelineLastLayer(CustomMlxLayer):
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if (
cache is not None
and hasattr(cache, "keys")
and getattr(cache, "keys", None) is not None
):
if cache is not None:
# This behavior changed upstream in mlx-lm (exact upstream change not tracked down yet)
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]

View File

@@ -0,0 +1,102 @@
from copy import deepcopy
from typing import Callable
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
import mlx.core as mx
from exo.engines.mlx import Model
from exo.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.engines.mlx.utils_mlx import make_kv_cache
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
best_snapshot_index, best_snapshot_length = None, 0
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
return self.caches[i]
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
return prompt_cache
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass
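A minimal standalone sketch (not part of the commit) of the cumprod trick that _get_prefix_length relies on: element-wise equality stays 1 until the first mismatch, so the sum of the cumulative product is the length of the shared token prefix.

import mlx.core as mx

def prefix_length(a: mx.array, b: mx.array) -> int:
    n = min(int(a.shape[0]), int(b.shape[0]))
    if n == 0:
        return 0
    equal = (a[:n] == b[:n]).astype(mx.int32)  # 1 where tokens match, 0 otherwise
    return int(mx.sum(mx.cumprod(equal)).item())  # cumprod zeroes out everything after the first mismatch

print(prefix_length(mx.array([1, 2, 3, 4]), mx.array([1, 2, 9, 4])))  # 2: only the first two tokens match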

View File

@@ -0,0 +1,17 @@
# TODO: Do we want so many constants?
KV_GROUP_SIZE = 32
KV_BITS = None
ATTENTION_KV_BITS = 4
MAX_TOKENS = 8192
MAX_KV_SIZE = 3200
KEEP_KV_SIZE = 1600
QUANTIZE_MODEL_MODE = "affine"
CACHE_GROUP_SIZE = 64
KV_CACHE_BITS = 8
TEMPERATURE = 1.0
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE = True
# TODO: Do we really want this?
HIDE_THINKING = False

View File

@@ -1,8 +1,10 @@
import os
import resource
import time
from typing import Any, Callable, cast
from mlx_lm.models.cache import KVCache, RotatingKVCache
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -22,6 +24,14 @@ from exo.engines.mlx.auto_parallel import (
pipeline_auto_parallel,
tensor_auto_parallel,
)
from exo.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
PATCH_SYSTEM_PROMPT,
TEMPERATURE,
TRUST_REMOTE_CODE,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
@@ -44,7 +54,6 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
mlx_rank: None | int = None
mlx_world_size: None | int = None
def mx_barrier(group: mx.distributed.Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -87,7 +96,7 @@ def mlx_distributed_init(
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
"""
rank = bound_instance.bound_shard().device_rank
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
# TODO: singleton instances
@@ -136,33 +145,40 @@ def initialize_mlx(
"""
mx.random.seed(42)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard()))
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
logger.info("Created a sampler")
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard().model_meta.model_id)
model, _ = load_model(model_path, strict=True)
# TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True
tokenizer = cast(
TokenizerWrapper,
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": True},
# TODO: HACK for Kimi K2 wrong eos token id
eos_token_ids=[163586] if "kimi-k2" in bound_instance.bound_shard().model_meta.model_id.lower() else None,
),
)
assert isinstance(tokenizer, TokenizerWrapper)
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, config = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
if isinstance(model.model, DeepseekV3Model):
pass
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
group = mlx_distributed_init(bound_instance)
model, tokenizer = shard_and_load(bound_instance.bound_shard(), group=group)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard()))
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
logger.debug(model)
return cast(Model, model), tokenizer, sampler
@@ -174,20 +190,28 @@ def shard_and_load(
model_path = build_model_path(shard_metadata.model_meta.model_id)
model, config = load_model(model_path, lazy=True, strict=False)
logger.info(f"{config=}")
logger.debug(model)
if isinstance(model.model, DeepseekV3Model):
pass
# TODO: See if we should quantize the model.
# def is_attention_layer(path: str) -> bool:
# path = path.lower()
# return "self_attn" in path and "layernorm" not in path
# def quant_predicate(path: str, module: nn.Module):
# if not isinstance(module, nn.Linear):
# return False
# return is_attention_layer(path)
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )
assert isinstance(model, nn.Module)
# TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True
tokenizer = cast(
TokenizerWrapper,
# TODO: HACK for Kimi K2 wrong eos token id
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": True},
# TODO: HACK for Kimi K2 wrong eos token id
eos_token_ids=[163586] if "kimi-k2" in shard_metadata.model_meta.model_id.lower() else None,
),
)
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
@@ -200,44 +224,63 @@ def shard_and_load(
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# TODO: Do we need this?
mx.eval(model)
logger.debug("SHARDED")
logger.debug(model)
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model, tokenizer
def get_tokenizer(model_path: str, shard_metadata: ShardMetadata):
tokenizer = cast(
TokenizerWrapper,
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
# TODO: HACK for Kimi K2 wrong eos token id
eos_token_ids=[163586]
if "kimi-k2" in shard_metadata.model_meta.model_id.lower()
else None,
),
)
assert isinstance(tokenizer, TokenizerWrapper)
return tokenizer
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
messages_dicts: list[dict[str, Any]] = [msg.model_dump() for msg in messages]
# Filter out None values, keeping relevant keys for the model
formatted_messages: list[dict[str, Any]] = []
for message in messages_dicts:
filtered_message: dict[str, Any] = {
k: v
for k, v in message.items() # pyright: ignore[reportAny]
if v is not None
}
for i, message in enumerate(messages):
if isinstance(message.content, ChatCompletionMessageText):
message.content = message.content.text
if isinstance(message.content, list):
if len(message.content) != 1:
logger.warning("Received malformed prompt")
continue
# Verify we have required fields
if "role" not in filtered_message:
raise ValueError(f"Message missing 'role' field: {filtered_message}")
if "content" not in filtered_message and "thinking" not in filtered_message:
# If neither content nor thinking is present, skip this message
message.content = message.content[0].text
if message.content is None and message.thinking is None:
continue
formatted_messages.append(filtered_message)
messages_dicts = formatted_messages
# Null values are not valid when applying templates in tokenizer
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None}
)
prompt: str = tokenizer.apply_chat_template( # type: ignore
messages_dicts,
formatted_messages,
tokenize=False,
add_generation_prompt=True,
)
@@ -269,16 +312,23 @@ class NullKVCache(KVCache):
def make_kv_cache(
model: Model,
max_kv_size: int | None = None,
) -> list[KVCache | RotatingKVCache]:
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
if max_kv_size is None:
logger.info("Using default KV cache")
return [KVCache() for _ in model.layers]
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")
return [KVCache() for _ in model.layers]
else:
logger.info("Using quantized KV cache")
return [
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
for _ in model.layers
]
else:
logger.info(f"Using rotating KV cache with {max_kv_size=}")
return [RotatingKVCache(max_size=max_kv_size) for _ in model.layers]
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
def mlx_force_oom(size: int = 40000) -> None:

View File

@@ -1,6 +1,6 @@
import signal
import argparse
import multiprocessing as mp
import signal
from dataclasses import dataclass
from typing import Self
@@ -8,14 +8,13 @@ import anyio
from anyio.abc import TaskGroup
from pydantic import PositiveInt
from exo.shared.logging import logger
import exo.routing.topics as topics
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.logging import logger, logger_cleanup, logger_setup
from exo.shared.types.commands import KillCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
@@ -119,6 +118,7 @@ class Node:
# if this is our second call to shutdown, just sys.exit
if self._tg.cancel_scope.cancel_called:
import sys
sys.exit(1)
self._tg.cancel_scope.cancel()
@@ -208,7 +208,7 @@ class Node:
def main():
args = Args.parse()
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)

View File

@@ -14,6 +14,7 @@ from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from exo.engines.mlx.constants import HIDE_THINKING
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.models.model_cards import MODEL_CARDS
@@ -45,6 +46,7 @@ from exo.shared.types.models import ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import OrderedBuffer
@@ -171,9 +173,9 @@ class API:
)
command = CreateInstance(
command_id=CommandId(),
model_meta=model_meta,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
sharding=payload.sharding,
)
await self._send(command)
@@ -194,7 +196,6 @@ class API:
raise HTTPException(status_code=404, detail="Instance not found")
command = DeleteInstance(
command_id=CommandId(),
instance_id=instance_id,
)
await self._send(command)
@@ -212,17 +213,26 @@ class API:
self._chat_completion_queues[command_id] = asyncio.Queue()
finished = False
is_thinking = False
while not finished:
# TODO: how long should this timeout be?
chunk = await asyncio.wait_for(
self._chat_completion_queues[command_id].get(), timeout=600
)
assert isinstance(chunk, TokenChunk)
# TODO: Do we want this?
if HIDE_THINKING:
if chunk.text == "<think>":
chunk.text = "\n"
if chunk.text == "</think>":
chunk.text = "\n"
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if not HIDE_THINKING or not is_thinking:
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
@@ -244,31 +254,6 @@ class API:
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
# Preprocess messages for GPT-OSS harmony format if needed
# TODO: This is slop surely we get rid
if "gpt-oss" in payload.model.lower():
import re
for message in payload.messages:
if message.content and "<|channel|>" in message.content:
# Parse harmony format tags
thinking_pattern = r"<\|channel\|>(.*?)(?=<\|message\|>|$)"
content_pattern = r"<\|message\|>(.*?)(?=<\|end\|>|$)"
thinking_match = re.search(
thinking_pattern, message.content, re.DOTALL
)
content_match = re.search(
content_pattern, message.content, re.DOTALL
)
if content_match:
# Extract the actual content
message.content = content_match.group(1).strip()
if thinking_match:
# Store thinking in the thinking field
message.thinking = thinking_match.group(1).strip()
if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
@@ -279,7 +264,6 @@ class API:
)
command = ChatCompletion(
command_id=CommandId(),
request_params=payload,
)
await self._send(command)
@@ -325,9 +309,19 @@ class API:
tg.start_soon(uvicorn_server.serve)
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._print_banner_when_ready, uvicorn_server)
self.command_sender.close()
self.global_event_receiver.close()
async def _print_banner_when_ready(self, uvicorn_server: uvicorn.Server):
"""Wait for the uvicorn server to be ready, then print the startup banner."""
# TODO: Is this the best condition to check for?
# The point is this should log when exo is ready.
while not uvicorn_server.started:
await asyncio.sleep(0.1)
print_startup_banner(self.port)
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:

View File

@@ -209,7 +209,7 @@ class Master:
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
# TODO: SQL
# TODO: SQL <- What does this mean?
self._event_log.append(event)
await self._send_event(indexed)

View File

@@ -3,6 +3,8 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
@@ -41,13 +43,13 @@ def get_instance_placements_after_create(
tb_only: bool = False,
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
from loguru import logger
logger.info("finding cycles:")
cycles = topology.get_cycles()
logger.info(f"{cycles=}")
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = cycles + singleton_cycles
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)

View File

@@ -210,27 +210,19 @@ def get_mlx_ibv_devices_matrix(
if i == j:
continue
# just for debugging for now...
for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph):
interface_name = _find_interface_name_for_ip(connection_ip, node_i)
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
matrix[i][j] = "rdma_en3" # TODO: hack, for now it's always en3
continue
for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph):
# Set the first valid rdma i -> j connection; if there are multiple, we pick one essentially at random. This is fine: the connection does not appear to need to be bidirectional.
if (
interface_name := _find_interface_name_for_ip(
connection_ip,
node_i,
)
) is not None:
# Find the IP J uses to talk to I
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current ibv backend requires all-to-all rdma connections"
)
@@ -246,8 +238,9 @@ def _find_connection_ip(
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_j.node_id
and connection.send_back_node_id == node_i.node_id
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
# TODO: Check if we need this.
and connection.send_back_multiaddr is not None
):
yield connection.send_back_multiaddr.ip_address
@@ -260,13 +253,13 @@ def _find_interface_name_for_ip(
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
logger.info(
f"Checking interface {interface.name} for IP {interface.ip_address} == {ip_address}: {interface.ip_address == ip_address}"
)
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address == ip_address:
logger.info("Found")
return f"rdma_{interface.name}"
return None
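For context, a hedged mini-model of the corrected lookup direction (the dataclass and helper below are hypothetical, not exo's actual types): to fill matrix[i][j], take a connection on which node j reports a send-back address owned by node i, then resolve that IP to an RDMA-capable interface on node i.

from dataclasses import dataclass

@dataclass
class Conn:  # hypothetical stand-in for exo's Connection
    local_node_id: str
    send_back_node_id: str
    send_back_ip: str  # address the send-back node exposes to the local node

def rdma_iface_for(i: str, j: str, conns: list[Conn], interfaces_on_i: dict[str, str]) -> str | None:
    """Interface node i should use to reach node j, as 'rdma_<name>', else None."""
    for c in conns:
        # Connections where j is local and i is the send-back node carry an IP owned by i.
        if c.local_node_id == j and c.send_back_node_id == i:
            name = interfaces_on_i.get(c.send_back_ip)  # ip -> interface name on node i
            if name is not None:
                return f"rdma_{name}"
    return None

# Example: node "b" reaches node "a" via 10.0.0.1, which is interface en3 on "a".
conns = [Conn(local_node_id="b", send_back_node_id="a", send_back_ip="10.0.0.1")]
print(rdma_iface_for("a", "b", conns, {"10.0.0.1": "en3"}))  # rdma_en3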

View File

@@ -41,11 +41,15 @@ def create_node():
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
@@ -53,7 +57,7 @@ def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.1/tcp/{send_back_port}"
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000

View File

@@ -1,6 +1,7 @@
from typing import Callable
import pytest
from loguru import logger
from exo.master.placement import (
get_instance_placements_after_create,
@@ -356,10 +357,18 @@ def test_tensor_rdma_backend_connectivity_matrix(
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
@@ -368,7 +377,12 @@ def test_tensor_rdma_backend_connectivity_matrix(
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
ip_address=conn_c_a.send_back_multiaddr.ip_address,
type="rdma",
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
type="rdma",
),
ethernet_interface,
@@ -381,9 +395,14 @@ def test_tensor_rdma_backend_connectivity_matrix(
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
type="rdma",
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
ip_address=conn_a_b.send_back_multiaddr.ip_address,
type="rdma",
),
ethernet_interface,
@@ -397,8 +416,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en5",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
type="rdma",
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
type="rdma",
),
ethernet_interface,
@@ -412,6 +436,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
create_instance_command = CreateInstance(
command_id=CommandId(),
@@ -444,9 +471,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
assert ":" in instance.mlx_ibv_coordinator
assert not instance.mlx_ibv_coordinator.startswith("169.254")

View File

@@ -252,9 +252,5 @@ def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> Sta
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
if not topology.contains_connection(event.edge) and topology.contains_connection(
event.edge.reverse()
):
topology.remove_connection(event.edge.reverse())
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -14,32 +14,32 @@ class ModelCard(CamelCaseModel):
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3-0324:4bit": ModelCard(
short_id="deepseek-v3-0324:4bit",
model_id="mlx-community/DeepSeek-V3-0324-4bit",
name="DeepSeek V3 0324 (4-bit)",
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
pretty_name="DeepSeek V3 0324 (4-bit)",
storage_size=Memory.from_kb(409706307),
n_layers=61,
),
),
"deepseek-v3-0324": ModelCard(
short_id="deepseek-v3-0324",
model_id="mlx-community/DeepSeek-v3-0324-8bit",
name="DeepSeek V3 0324 (8-bit)",
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
pretty_name="DeepSeek V3 0324 (8-bit)",
storage_size=Memory.from_kb(754706307),
n_layers=61,
),
),
# "deepseek-v3-0324:4bit": ModelCard(
# short_id="deepseek-v3-0324:4bit",
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
# name="DeepSeek V3 0324 (4-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
# pretty_name="DeepSeek V3 0324 (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# ),
# ),
# "deepseek-v3-0324": ModelCard(
# short_id="deepseek-v3-0324",
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
# name="DeepSeek V3 0324 (8-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
# pretty_name="DeepSeek V3 0324 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# ),
# ),
"deepseek-v3.1": ModelCard(
short_id="deepseek-v3.1",
model_id="mlx-community/DeepSeek-V3.1-8bit",
@@ -67,32 +67,32 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
),
# deepseek r1
"deepseek-r1-0528:4bit": ModelCard(
short_id="deepseek-r1-0528:4bit",
model_id="mlx-community/DeepSeek-R1-0528-4bit",
name="DeepSeek-R1-0528 (4-bit)",
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
pretty_name="DeepSeek R1 671B (4-bit)",
storage_size=Memory.from_kb(409706307),
n_layers=61,
),
),
"deepseek-r1-0528": ModelCard(
short_id="deepseek-r1-0528",
model_id="mlx-community/DeepSeek-R1-0528-8bit",
name="DeepSeek-R1-0528 (8-bit)",
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
pretty_name="DeepSeek R1 671B (8-bit)",
storage_size=Memory.from_bytes(754998771712),
n_layers=61,
),
),
# "deepseek-r1-0528:4bit": ModelCard(
# short_id="deepseek-r1-0528:4bit",
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
# name="DeepSeek-R1-0528 (4-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
# pretty_name="DeepSeek R1 671B (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# ),
# ),
# "deepseek-r1-0528": ModelCard(
# short_id="deepseek-r1-0528",
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
# name="DeepSeek-R1-0528 (8-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
# pretty_name="DeepSeek R1 671B (8-bit)",
# storage_size=Memory.from_bytes(754998771712),
# n_layers=61,
# ),
# ),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
@@ -228,19 +228,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=32,
),
),
"phi-3-mini:128k": ModelCard(
short_id="phi-3-mini:128k",
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
name="Phi 3 Mini 128k",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
pretty_name="Phi 3 Mini 128k",
storage_size=Memory.from_kb(2099262),
n_layers=32,
),
),
# "phi-3-mini:128k": ModelCard(
# short_id="phi-3-mini:128k",
# model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
# name="Phi 3 Mini 128k",
# description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
# pretty_name="Phi 3 Mini 128k",
# storage_size=Memory.from_kb(2099262),
# n_layers=32,
# ),
# ),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
@@ -268,19 +268,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
),
),
"qwen3-235b-a22b": ModelCard(
short_id="qwen3-235b-a22b",
model_id="mlx-community/Qwen3-235B-A22B-4bit",
name="Qwen3 235B, Active 22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"),
pretty_name="Qwen3 235B, Active 22B (4-bit)",
storage_size=Memory.from_kb(123207680),
n_layers=94,
),
),
# "qwen3-235b-a22b": ModelCard(
# short_id="qwen3-235b-a22b",
# model_id="mlx-community/Qwen3-235B-A22B-4bit",
# name="Qwen3 235B, Active 22B (4-bit)",
# description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"),
# pretty_name="Qwen3 235B, Active 22B (4-bit)",
# storage_size=Memory.from_kb(123207680),
# n_layers=94,
# ),
# ),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id="mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit",
@@ -308,31 +308,31 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=40,
),
),
"granite-3.3-8b": ModelCard(
short_id="granite-3.3-8b",
model_id="mlx-community/granite-3.3-8b-instruct-fp16",
name="Granite 3.3 8B",
description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
pretty_name="Granite 3.3 8B",
storage_size=Memory.from_kb(15958720),
n_layers=40,
),
),
# "granite-3.3-8b": ModelCard(
# short_id="granite-3.3-8b",
# model_id="mlx-community/granite-3.3-8b-instruct-fp16",
# name="Granite 3.3 8B",
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# pretty_name="Granite 3.3 8B",
# storage_size=Memory.from_kb(15958720),
# n_layers=40,
# ),
# ),
# smol-lm
"smol-lm-135m": ModelCard(
short_id="smol-lm-135m",
model_id="mlx-community/SmolLM-135M-4bit",
name="Smol LM 135M",
description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
pretty_name="Smol LM 135M",
storage_size=Memory.from_kb(73940),
n_layers=30,
),
),
# "smol-lm-135m": ModelCard(
# short_id="smol-lm-135m",
# model_id="mlx-community/SmolLM-135M-4bit",
# name="Smol LM 135M",
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
# pretty_name="Smol LM 135M",
# storage_size=Memory.from_kb(73940),
# n_layers=30,
# ),
# ),
}

View File

@@ -70,6 +70,9 @@ class Topology:
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]

View File

@@ -24,13 +24,20 @@ class ModelListModel(BaseModel):
class ModelList(BaseModel):
object: str = "list"
object: Literal["list"] = "list"
data: list[ModelListModel]
class ChatCompletionMessageText(BaseModel):
type: Literal["text"] = "text"
text: str
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: str | None = None
content: (
str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
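A hedged sketch of the widened content union, constructed directly from the models above; apply_chat_template (in exo.engines.mlx.utils_mlx) flattens the non-string forms back to plain text before templating.

from exo.shared.types.api import ChatCompletionMessage, ChatCompletionMessageText

plain = ChatCompletionMessage(role="user", content="Hello")
single_part = ChatCompletionMessage(role="user", content=ChatCompletionMessageText(text="Hello"))
multi_part = ChatCompletionMessage(role="user", content=[ChatCompletionMessageText(text="Hello")])  # lists are expected to hold exactly one part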
@@ -55,20 +62,6 @@ class Logprobs(BaseModel):
content: list[LogprobsContentItem] | None = None
class StreamingChoiceResponse(BaseModel):
index: int
delta: ChatCompletionMessage
logprobs: Logprobs | None = None
finish_reason: FinishReason | None = None
class ChatCompletionChoice(BaseModel):
index: int
message: ChatCompletionMessage
logprobs: Logprobs | None = None
finish_reason: FinishReason | None = None
class PromptTokensDetails(BaseModel):
cached_tokens: int = 0
audio_tokens: int = 0
@@ -89,6 +82,21 @@ class Usage(BaseModel):
completion_tokens_details: CompletionTokensDetails | None = None
class StreamingChoiceResponse(BaseModel):
index: int
delta: ChatCompletionMessage
logprobs: Logprobs | None = None
finish_reason: FinishReason | None = None
usage: Usage | None = None
class ChatCompletionChoice(BaseModel):
index: int
message: ChatCompletionMessage
logprobs: Logprobs | None = None
finish_reason: FinishReason | None = None
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
@@ -125,8 +133,8 @@ class CreateInstanceTaskParams(BaseModel):
# TODO: in future the user could specify a specific Instance, not just a model_id
model_id: str
sharding: Sharding = Sharding.Pipeline
# TODO: fix
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
class DeleteInstanceTaskParams(BaseModel):

View File

@@ -29,6 +29,7 @@ class CreateInstance(BaseCommand):
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
class DeleteInstance(BaseCommand):

View File

@@ -12,20 +12,17 @@ class NodeInfo(CamelCaseModel):
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr | None
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
if self.send_back_multiaddr:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
else:
return hash((self.local_node_id, self.send_back_node_id))
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
@@ -37,13 +34,4 @@ class Connection(CamelCaseModel):
)
def is_thunderbolt(self) -> bool:
return self.send_back_multiaddr is not None and str(
self.send_back_multiaddr.ipv4_address
).startswith("169.254")
def reverse(self) -> "Connection":
return Connection(
local_node_id=self.send_back_node_id,
send_back_node_id=self.local_node_id,
send_back_multiaddr=None,
)
return str(self.send_back_multiaddr.ipv4_address).startswith("200.0")

View File

@@ -41,6 +41,7 @@ class BoundInstance(CamelCaseModel):
instance: Instance
bound_runner_id: RunnerId
@property
def bound_shard(self) -> ShardMetadata:
shard = self.instance.shard(self.bound_runner_id)
assert shard is not None

View File

@@ -48,6 +48,7 @@ class RunnerRunning(BaseRunnerStatus):
class RunnerShutdown(BaseRunnerStatus):
pass
class RunnerFailed(BaseRunnerStatus):
error_message: str | None = None

src/exo/utils/banner.py (new file, 34 lines)
View File

@@ -0,0 +1,34 @@
def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
banner = """
╔═══════════════════════════════════════════════════════════════════════╗
║ ║
║ ███████╗██╗ ██╗ ██████╗ ║
║ ██╔════╝╚██╗██╔╝██╔═══██╗ ║
║ █████╗ ╚███╔╝ ██║ ██║ ║
║ ██╔══╝ ██╔██╗ ██║ ██║ ║
║ ███████╗██╔╝ ██╗╚██████╔╝ ║
║ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ║
║ ║
║ Distributed AI Inference Cluster ║
║ ║
╚═══════════════════════════════════════════════════════════════════════╝
"""
dashboard_url = f"http://localhost:{port}"
api_info = f"""
╔═══════════════════════════════════════════════════════════════════════╗
║ ║
║ 🌐 Dashboard & API Ready ║
║ ║
{dashboard_url}{" " * (69 - len(dashboard_url))}
║ ║
║ Click the URL above to open the dashboard in your browser ║
║ ║
╚═══════════════════════════════════════════════════════════════════════╝
"""
print(banner)
print(api_info)
print()

View File

@@ -139,7 +139,9 @@ class MpSender[T]:
# == unique to Mp channels ==
def join(self) -> None:
"""Ensure any queued messages are resolved before continuing"""
assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
assert self._state.closed.is_set(), (
"Mp channels must be closed before being joined"
)
self._state.buffer.join_thread()
# == context manager support ==
@@ -209,7 +211,9 @@ class MpReceiver[T]:
# == unique to Mp channels ==
def join(self) -> None:
"""Block until all enqueued messages are drained off our side of the buffer"""
assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
assert self._state.closed.is_set(), (
"Mp channels must be closed before being joined"
)
self._state.buffer.join_thread()
# == iterator support ==

View File

@@ -100,12 +100,12 @@ def _model_needs_download(
for runner in runners.values():
if (
isinstance(runner.status, RunnerWaitingForModel)
and runner.bound_instance.bound_shard() not in download_status
and runner.bound_instance.bound_shard not in download_status
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
shard_metadata=runner.bound_instance.bound_shard(),
shard_metadata=runner.bound_instance.bound_shard,
)
@@ -160,7 +160,7 @@ def _ready_to_warmup(
)
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
)
and runner.bound_instance.bound_shard().device_rank != 0
and runner.bound_instance.bound_shard.device_rank != 0
)
or (
all(
@@ -170,7 +170,7 @@ def _ready_to_warmup(
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
if global_runner_id != runner.bound_instance.bound_runner_id
)
and runner.bound_instance.bound_shard().device_rank == 0
and runner.bound_instance.bound_shard.device_rank == 0
)
):
return StartWarmup(instance_id=runner.bound_instance.instance.instance_id)

View File

@@ -6,6 +6,9 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.engines.mlx import Model
# from exo.engines.mlx.cache import KVPrefixCache
from exo.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
@@ -70,6 +73,8 @@ def warmup_inference(
sampler=sampler,
prompt_cache=cache,
prefill_step_size=65536,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info("Generated warmup token: " + str(_r.text))
tokens_generated += 1
@@ -94,19 +99,19 @@ def mlx_generate(
chat_task_data=task,
)
cache = make_kv_cache(
model=model,
)
caches = make_kv_cache(model=model)
max_tokens = task.max_tokens or 1000
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
prompt_cache=cache,
prompt_cache=caches,
prefill_step_size=65536,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info(out.text)
if out.finish_reason is not None and out.finish_reason not in get_args(

View File

@@ -22,7 +22,7 @@ from exo.shared.types.tasks import (
)
from exo.shared.types.worker.commands_runner import (
GenerationResponse,
TokenizedResponse,
# TokenizedResponse,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
@@ -31,12 +31,12 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
RunnerShutdown
)
from exo.utils.channels import MpReceiver, MpSender, ClosedResourceError
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.generate import mlx_generate, warmup_inference
@@ -49,7 +49,7 @@ def main(
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
bound_instance.bound_shard(),
bound_instance.bound_shard,
)
try:
logger.info("hello from the runner")
@@ -115,6 +115,7 @@ def main(
model=model,
tokenizer=tokenizer,
sampler=sampler,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -185,9 +186,9 @@ def main(
),
)
)
case TokenizedResponse():
# TODO: something here ig
logger.info("Finished tokenizing?")
# case TokenizedResponse():
# TODO: something here ig
# logger.info("Finished tokenizing?")
current_status = RunnerReady()
logger.info("runner ready")
@@ -212,9 +213,7 @@ def main(
)
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerShutdown()
)
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")

View File

@@ -67,7 +67,7 @@ class RunnerSupervisor:
daemon=True,
)
shard_metadata = bound_instance.bound_shard()
shard_metadata = bound_instance.bound_shard
self = cls(
bound_instance=bound_instance,
@@ -109,12 +109,13 @@ class RunnerSupervisor:
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGKILL. System resources may have leaked")
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
event = anyio.Event()
@@ -126,7 +127,6 @@ class RunnerSupervisor:
return
await event.wait()
async def _forward_events(self):
with self._ev_recv as events:
try:
@@ -140,7 +140,6 @@ class RunnerSupervisor:
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
def __del__(self) -> None:
if self.runner_process.is_alive():
logger.warning("RunnerSupervisor was not stopped cleanly.")
@@ -152,7 +151,7 @@ class RunnerSupervisor:
await to_thread.run_sync(self.runner_process.join, 1)
rc = self.runner_process.exitcode
if rc == 0:
#
#
return
if isinstance(rc, int) and rc < 0:

View File

@@ -1,4 +1,5 @@
import pytest
from exo.worker.common import AssignedRunner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.state import State
@@ -27,7 +28,6 @@ from exo.shared.types.worker.runners import (
RunningRunnerStatus,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.common import AssignedRunner
from exo.worker.main import Worker
from exo.worker.plan import plan
from exo.worker.tests.constants import (

View File

@@ -13,9 +13,9 @@ QUERY="$*"
curl -sN -X POST "http://$HOST:8000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{
\"model\": \"mlx-community/Llama-3.3-70B-Instruct-8bit\",
\"model\": \"mlx-community/Kimi-K2-Thinking\",
\"stream\": true,
\"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\" }]
\"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\"}]
}" |
grep --line-buffered '^data:' |
grep --line-buffered -v 'data: \[DONE\]' |