diff --git a/.github/configs/bench_simple.yaml b/.github/configs/bench_simple.yaml
index 18f7042b..9a76b6db 100644
--- a/.github/configs/bench_simple.yaml
+++ b/.github/configs/bench_simple.yaml
@@ -26,7 +26,7 @@ model_ids:
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Sharding strategy: "Pipeline" or "Tensor"
-sharding: "Pipeline"
+sharding: "Tensor"
# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"
@@ -48,6 +48,11 @@ stages:
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
+ # - name: "pp64_g64"
+ # prompt_length: 64
+ # generation_length: 64
+ # time_between_requests: 2.0
+ # iterations: 5
# - name: "pp64_g512"
# prompt_length: 64
# generation_length: 512
@@ -58,6 +63,11 @@ stages:
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
+ - name: "pp256_g64"
+ prompt_length: 256
+ generation_length: 64
+ time_between_requests: 2.0
+ iterations: 5
# - name: "pp256_g512"
# prompt_length: 256
# generation_length: 512
@@ -83,26 +93,26 @@ stages:
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
- - name: "pp4096_g64"
- prompt_length: 4096
- generation_length: 64
- time_between_requests: 2.0
- iterations: 4
+ # - name: "pp4096_g64"
+ # prompt_length: 4096
+ # generation_length: 64
+ # time_between_requests: 2.0
+ # iterations: 4
# - name: "pp4096_g512"
# prompt_length: 4096
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
- - name: "pp8192_g64"
- prompt_length: 8192
- generation_length: 64
- time_between_requests: 2.0
- iterations: 4
+ # - name: "pp8192_g64"
+ # prompt_length: 8192
+ # generation_length: 64
+ # time_between_requests: 2.0
+ # iterations: 5
# - name: "pp8192_g512"
# prompt_length: 8192
# generation_length: 512
# time_between_requests: 2.0
- # iterations: 10
+ # iterations: 5
# - name: "pp16384_g64"
# prompt_length: 16384
# generation_length: 64
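For reference, the stage entries above are plain YAML mappings. A hedged sketch of reading them with a generic PyYAML loader (the actual bench harness's loader may differ):

```python
# Sketch: iterate the benchmark stages defined in bench_simple.yaml.
# Generic PyYAML load; not exo's own bench runner.
import yaml

with open(".github/configs/bench_simple.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["sharding"], cfg["instance_meta"])  # e.g. "Tensor", "MlxIbv"
for stage in cfg["stages"]:
    # Each enabled stage carries a name plus prompt/generation lengths,
    # the pacing between requests, and an iteration count.
    print(stage["name"], stage["prompt_length"], stage["generation_length"],
          stage["time_between_requests"], stage["iterations"])
```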
diff --git a/.mlx_typings/mlx_lm/models/cache.pyi b/.mlx_typings/mlx_lm/models/cache.pyi
index 30fe1b85..177dde3a 100644
--- a/.mlx_typings/mlx_lm/models/cache.pyi
+++ b/.mlx_typings/mlx_lm/models/cache.pyi
@@ -36,9 +36,7 @@ def save_prompt_cache(
state.
"""
-def load_prompt_cache(
- file_name, return_metadata=...
-): # -> tuple[list[Any], Any] | list[Any]:
+def load_prompt_cache(file_name: str, return_metadata=...) -> array:
"""
Load a prompt cache from a file.
diff --git a/dashboard/index.html b/dashboard/index.html
index 62ec32f5..d0ddc6fc 100644
--- a/dashboard/index.html
+++ b/dashboard/index.html
@@ -31,10 +31,10 @@
max-width: 1200px;
margin-bottom: 15px;
margin-top: 20px;
- text-align: left;
- display: flex;
- justify-content: space-between;
+ display: grid;
+ grid-template-columns: 1fr auto 1fr;
align-items: flex-end;
+ gap: 20px;
}
.dashboard-header h1 {
@@ -67,6 +67,18 @@
flex-direction: column;
}
+ .header-center {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ }
+
+ .header-right {
+ display: flex;
+ justify-content: flex-end;
+ }
+
.header-instances-button {
background-color: transparent;
border: 1px solid var(--exo-medium-gray);
@@ -972,11 +984,11 @@
@@ -986,16 +998,26 @@
+
+
@@ -1004,10 +1026,15 @@
@@ -1068,12 +1095,14 @@
const modelSelect = document.getElementById('modelSelect');
const launchInstanceButton = document.getElementById('launchInstanceButton');
const launchStatus = document.getElementById('launchStatus');
+ const minNodesOptions = document.getElementById('minNodesOptions');
const USE_MOCK_DATA = false; // <<< FLAG TO TOGGLE MOCK DATA
let currentlySelectedNodeId = null; // To store the ID of the currently selected node
let nodeIdToFriendlyName = {}; // Map nodeId -> friendly name for download sections
let instanceIdToColor = {}; // Map instanceId -> color for visual coding
let connectionToInstances = {}; // Map "nodeA|nodeB" -> [instanceIds] using that connection
+ let currentNodeCount = 1; // Track the current number of nodes in topology
const API_ENDPOINT = window.location.origin + window.location.pathname.replace(/\/$/, "") + '/state';
const REFRESH_INTERVAL = 1000; // 1 second
@@ -1218,6 +1247,16 @@
instancesMenuButton.classList.toggle('active', sidebarOpen);
}
+ // Edge IP display flag (can be toggled from console)
+ window.exoShowEdgeIPs = false;
+
+ // Helper function to toggle IP display (accessible from console)
+ window.toggleEdgeIPs = function() {
+ window.exoShowEdgeIPs = !window.exoShowEdgeIPs;
+ console.log(`Edge IP display ${window.exoShowEdgeIPs ? 'enabled' : 'disabled'}`);
+ return window.exoShowEdgeIPs;
+ };
+
// Fetch available models and populate dropdown
async function fetchAndPopulateModels() {
try {
@@ -1281,8 +1320,11 @@
const selectedSharding = document.querySelector('input[name="sharding"]:checked').value;
const selectedInstanceMeta = document.querySelector('input[name="instance_meta"]:checked').value;
+ const minNodesRadio = document.querySelector('input[name="min_nodes"]:checked');
+ const minNodes = minNodesRadio ? parseInt(minNodesRadio.value, 10) : 1;
console.log("selectedSharding", selectedSharding);
console.log("selectedInstanceMeta", selectedInstanceMeta);
+ console.log("minNodes", minNodes);
try {
showLaunchStatus('Launching instance...', 'loading');
@@ -1296,7 +1338,8 @@
body: JSON.stringify({
model_id: selectedModelId,
sharding: selectedSharding,
- instance_meta: selectedInstanceMeta
+ instance_meta: selectedInstanceMeta,
+ min_nodes: minNodes
})
});
@@ -1858,6 +1901,39 @@
const edgesData = (topologyData && Array.isArray(topologyData.edges)) ? topologyData.edges : [];
const nodeIds = Object.keys(nodesData);
+ // Update min nodes radio buttons based on current topology
+ currentNodeCount = Math.max(1, nodeIds.length);
+ if (minNodesOptions) {
+ // Get currently selected value before regenerating
+ const currentlySelected = document.querySelector('input[name="min_nodes"]:checked');
+ const hasOnlyDefaultOption = minNodesOptions.children.length === 1;
+ // Default to maximum nodes on initial load, otherwise preserve user selection
+ const selectedValue = (currentlySelected && !hasOnlyDefaultOption) ? parseInt(currentlySelected.value, 10) : currentNodeCount;
+
+ // Clear and regenerate radio buttons
+ minNodesOptions.innerHTML = '';
+ for (let i = 1; i <= currentNodeCount; i++) {
+ const optionDiv = document.createElement('div');
+ optionDiv.className = 'strategy-option';
+
+ const radio = document.createElement('input');
+ radio.type = 'radio';
+ radio.id = `minNodes${i}`;
+ radio.name = 'min_nodes';
+ radio.value = i.toString();
+ // Check if this should be selected (preserve selection or default to maximum)
+ radio.checked = (i === Math.min(selectedValue, currentNodeCount));
+
+ const label = document.createElement('label');
+ label.htmlFor = `minNodes${i}`;
+ label.textContent = i.toString();
+
+ optionDiv.appendChild(radio);
+ optionDiv.appendChild(label);
+ minNodesOptions.appendChild(optionDiv);
+ }
+ }
+
if (nodeIds.length === 0) {
const textEl = document.createElementNS('http://www.w3.org/2000/svg', 'text');
textEl.setAttribute('x', '50%');
@@ -2002,7 +2078,7 @@
arrowsGroup.appendChild(arrowSeg);
// Add label for A->B direction (show all connections)
- if (entry.aToBEdges && entry.aToBEdges.length > 0) {
+ if (window.exoShowEdgeIPs && entry.aToBEdges && entry.aToBEdges.length > 0) {
// Count occurrences of each IP/interface combination
const connectionCounts = new Map();
@@ -2067,7 +2143,7 @@
arrowsGroup.appendChild(arrowSeg);
// Add label for B->A direction (show all connections)
- if (entry.bToAEdges && entry.bToAEdges.length > 0) {
+ if (window.exoShowEdgeIPs && entry.bToAEdges && entry.bToAEdges.length > 0) {
// Count occurrences of each IP/interface combination
const connectionCounts = new Map();
diff --git a/src/exo/engines/mlx/auto_parallel.py b/src/exo/engines/mlx/auto_parallel.py
index 345454db..4ff747b8 100644
--- a/src/exo/engines/mlx/auto_parallel.py
+++ b/src/exo/engines/mlx/auto_parallel.py
@@ -3,7 +3,10 @@ from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast, override
-from mlx_lm.models.cache import KVCache, RotatingKVCache
+from mlx_lm.models.cache import (
+ KVCache,
+ _BaseCache, # pyright: ignore[reportPrivateUsage]
+)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.llama import Model as LlamaModel
@@ -91,7 +94,7 @@ class PipelineLastLayer(CustomMlxLayer):
x, *args, **kwargs
).arguments.get("cache", None)
- assert cache is None or isinstance(cache, (KVCache, RotatingKVCache))
+ assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
output: mx.array = self.original_layer(x, *args, **kwargs)
@@ -99,11 +102,7 @@ class PipelineLastLayer(CustomMlxLayer):
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
- if (
- cache is not None
- and hasattr(cache, "keys")
- and getattr(cache, "keys", None) is not None
- ):
+ if cache is not None:
# This change happened upstream - check out mlx github somewhere??
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
diff --git a/src/exo/engines/mlx/cache.py b/src/exo/engines/mlx/cache.py
new file mode 100644
index 00000000..f4e7df8d
--- /dev/null
+++ b/src/exo/engines/mlx/cache.py
@@ -0,0 +1,102 @@
+from copy import deepcopy
+from typing import Callable
+
+from mlx_lm import stream_generate
+from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
+from mlx_lm.tokenizer_utils import TokenizerWrapper
+
+import mlx.core as mx
+from exo.engines.mlx import Model
+from exo.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
+from exo.engines.mlx.utils_mlx import make_kv_cache
+
+
+class KVPrefixCache:
+ def __init__(self):
+ # Only one prefix cache per runner.
+ self.prompts: list[mx.array] = [] # mx array of tokens (ints)
+ self.caches: list[list[_BaseCache]] = []
+
+ def add_kv_cache(
+ self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
+ ):
+ tokenized_prompt = self.encode_prompt(tokenizer, prompt)
+ self.prompts.append(tokenized_prompt)
+ self.caches.append(deepcopy(cache))
+
+ def get_kv_cache(
+ self,
+ model: Model,
+ tokenizer: TokenizerWrapper,
+ sampler: Callable[[mx.array], mx.array],
+ prompt: str,
+ ) -> list[_BaseCache]:
+ tokenized_prompt = self.encode_prompt(tokenizer, prompt)
+ max_length = len(tokenized_prompt)
+
+ best_snapshot_index, best_snapshot_length = None, 0
+
+ for i, cached_prompt in enumerate(self.prompts):
+ length = _get_prefix_length(tokenized_prompt, cached_prompt)
+
+ if length == max_length:
+ return self.caches[i]
+
+ if length > best_snapshot_length:
+ best_snapshot_index, best_snapshot_length = i, length
+
+ if best_snapshot_index is not None:
+            prompt_cache = deepcopy(self.caches[best_snapshot_index])
+            # Trim the copied cache back to the shared prefix, then prefill only the new suffix.
+            cached_prompt_length = int(self.prompts[best_snapshot_index].shape[0])
+            trim_prompt_cache(prompt_cache, cached_prompt_length - best_snapshot_length)
+            tokenized_prompt = tokenized_prompt[best_snapshot_length:]
+
+ else:
+ prompt_cache = make_kv_cache(
+ model,
+ # max_kv_size=MAX_KV_SIZE,
+ # keep=KEEP_KV_SIZE
+ )
+
+ prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
+
+ return prompt_cache
+
+ def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
+ add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
+ tokenizer.bos_token
+ )
+ tokenized_prompt = tokenizer.encode(
+ prompt, add_special_tokens=add_special_tokens
+ )
+ return mx.array(tokenized_prompt)
+
+
+def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
+ n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
+ if n == 0:
+ return 0
+
+ equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
+ prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
+ return int(mx.sum(prefix_mask).item())
+
+
+def prefill(
+ model: Model,
+ tokenizer: TokenizerWrapper,
+ sampler: Callable[[mx.array], mx.array],
+ prompt: mx.array,
+ cache: list[_BaseCache],
+) -> None:
+ for _ in stream_generate(
+ model=model,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ max_tokens=0,
+ sampler=sampler,
+ prompt_cache=cache,
+ prefill_step_size=2048,
+ kv_group_size=KV_GROUP_SIZE,
+ kv_bits=KV_BITS,
+ ):
+ pass
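The prefix matching in `_get_prefix_length` relies on `cumprod` zeroing out everything after the first mismatch. A standalone sketch with made-up token arrays (the real code also caps the comparison at `KEEP_KV_SIZE`):

```python
# Illustration of the cumprod prefix-length trick used above.
# The token arrays are invented for the example.
import mlx.core as mx

new_prompt = mx.array([1, 2, 3, 4, 5])
cached_prompt = mx.array([1, 2, 3, 9, 9])

n = min(int(new_prompt.shape[0]), int(cached_prompt.shape[0]))
equal = (new_prompt[:n] == cached_prompt[:n]).astype(mx.int32)  # [1, 1, 1, 0, 0]
prefix_mask = mx.cumprod(equal)                                 # [1, 1, 1, 0, 0]
print(int(mx.sum(prefix_mask).item()))  # 3: the first 3 cached tokens can be reused
```

With three tokens shared, `get_kv_cache` trims the copied cache back to that prefix and prefills only the remaining suffix.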
diff --git a/src/exo/engines/mlx/constants.py b/src/exo/engines/mlx/constants.py
new file mode 100644
index 00000000..c73d62d3
--- /dev/null
+++ b/src/exo/engines/mlx/constants.py
@@ -0,0 +1,17 @@
+# TODO: Do we want so many constants?
+
+KV_GROUP_SIZE = 32
+KV_BITS = None
+ATTENTION_KV_BITS = 4
+MAX_TOKENS = 8192
+MAX_KV_SIZE = 3200
+KEEP_KV_SIZE = 1600
+QUANTIZE_MODEL_MODE = "affine"
+CACHE_GROUP_SIZE = 64
+KV_CACHE_BITS = 8
+TEMPERATURE = 1.0
+
+# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
+TRUST_REMOTE_CODE = True
+# TODO: Do we really want this?
+HIDE_THINKING = False
diff --git a/src/exo/engines/mlx/utils_mlx.py b/src/exo/engines/mlx/utils_mlx.py
index 5f42ca9c..8c48bd2e 100644
--- a/src/exo/engines/mlx/utils_mlx.py
+++ b/src/exo/engines/mlx/utils_mlx.py
@@ -1,8 +1,10 @@
import os
import resource
+import time
from typing import Any, Callable, cast
-from mlx_lm.models.cache import KVCache, RotatingKVCache
+from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
+from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -22,6 +24,14 @@ from exo.engines.mlx.auto_parallel import (
pipeline_auto_parallel,
tensor_auto_parallel,
)
+from exo.engines.mlx.constants import (
+ CACHE_GROUP_SIZE,
+ KV_CACHE_BITS,
+ TEMPERATURE,
+ TRUST_REMOTE_CODE,
+)
+from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
@@ -44,7 +54,6 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
mlx_rank: None | int = None
mlx_world_size: None | int = None
-
def mx_barrier(group: mx.distributed.Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -87,7 +96,7 @@ def mlx_distributed_init(
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
"""
- rank = bound_instance.bound_shard().device_rank
+ rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
# TODO: singleton instances
@@ -136,33 +145,40 @@ def initialize_mlx(
"""
mx.random.seed(42)
- set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard()))
+ set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
- sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
+ sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
logger.info("Created a sampler")
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
logger.info(f"Single device used for {bound_instance.instance}")
- model_path = build_model_path(bound_instance.bound_shard().model_meta.model_id)
- model, _ = load_model(model_path, strict=True)
- # TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True
- tokenizer = cast(
- TokenizerWrapper,
- load_tokenizer(
- model_path,
- tokenizer_config_extra={"trust_remote_code": True},
- # TODO: HACK for Kimi K2 wrong eos token id
- eos_token_ids=[163586] if "kimi-k2" in bound_instance.bound_shard().model_meta.model_id.lower() else None,
- ),
- )
- assert isinstance(tokenizer, TokenizerWrapper)
+ model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
+ start_time = time.perf_counter()
+ model, config = load_model(model_path, strict=True)
+ end_time = time.perf_counter()
+ logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
+ if isinstance(model.model, DeepseekV3Model):
+ pass
+ # model, config = quantize_model(
+ # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
+ # )
+
+ tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
group = mlx_distributed_init(bound_instance)
- model, tokenizer = shard_and_load(bound_instance.bound_shard(), group=group)
- set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard()))
+ start_time = time.perf_counter()
+ model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
+ end_time = time.perf_counter()
+ logger.info(
+ f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
+ )
+
+ set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
+
+ logger.debug(model)
return cast(Model, model), tokenizer, sampler
@@ -174,20 +190,28 @@ def shard_and_load(
model_path = build_model_path(shard_metadata.model_meta.model_id)
model, config = load_model(model_path, lazy=True, strict=False)
- logger.info(f"{config=}")
+ logger.debug(model)
+ if isinstance(model.model, DeepseekV3Model):
+ pass
+ # TODO: See if we should quantize the model.
+ # def is_attention_layer(path: str) -> bool:
+ # path = path.lower()
+
+ # return "self_attn" in path and "layernorm" not in path
+
+
+ # def quant_predicate(path: str, module: nn.Module):
+ # if not isinstance(module, nn.Linear):
+ # return False
+
+ # return is_attention_layer(path)
+ # model, config = quantize_model(
+ # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
+ # )
+
assert isinstance(model, nn.Module)
- # TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True
- tokenizer = cast(
- TokenizerWrapper,
- # TODO: HACK for Kimi K2 wrong eos token id
- load_tokenizer(
- model_path,
- tokenizer_config_extra={"trust_remote_code": True},
- # TODO: HACK for Kimi K2 wrong eos token id
- eos_token_ids=[163586] if "kimi-k2" in shard_metadata.model_meta.model_id.lower() else None,
- ),
- )
+ tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
@@ -200,44 +224,63 @@ def shard_and_load(
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
+
+ # TODO: Do we need this?
mx.eval(model)
+ logger.debug("SHARDED")
+ logger.debug(model)
+
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
return model, tokenizer
+def get_tokenizer(model_path: str, shard_metadata: ShardMetadata):
+ tokenizer = cast(
+ TokenizerWrapper,
+ load_tokenizer(
+ model_path,
+ tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
+ # TODO: HACK for Kimi K2 wrong eos token id
+ eos_token_ids=[163586]
+ if "kimi-k2" in shard_metadata.model_meta.model_id.lower()
+ else None,
+ ),
+ )
+ assert isinstance(tokenizer, TokenizerWrapper)
+
+ return tokenizer
+
+
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
- messages_dicts: list[dict[str, Any]] = [msg.model_dump() for msg in messages]
- # Filter out None values, keeping relevant keys for the model
formatted_messages: list[dict[str, Any]] = []
- for message in messages_dicts:
- filtered_message: dict[str, Any] = {
- k: v
- for k, v in message.items() # pyright: ignore[reportAny]
- if v is not None
- }
+ for i, message in enumerate(messages):
+ if isinstance(message.content, ChatCompletionMessageText):
+ message.content = message.content.text
+ if isinstance(message.content, list):
+ if len(message.content) != 1:
+ logger.warning("Received malformed prompt")
+ continue
- # Verify we have required fields
- if "role" not in filtered_message:
- raise ValueError(f"Message missing 'role' field: {filtered_message}")
- if "content" not in filtered_message and "thinking" not in filtered_message:
- # If neither content nor thinking is present, skip this message
+ message.content = message.content[0].text
+ if message.content is None and message.thinking is None:
continue
- formatted_messages.append(filtered_message)
-
- messages_dicts = formatted_messages
+ # Null values are not valid when applying templates in tokenizer
+ formatted_messages.append(
+ {k: v for k, v in message.model_dump().items() if v is not None}
+ )
prompt: str = tokenizer.apply_chat_template( # type: ignore
- messages_dicts,
+ formatted_messages,
tokenize=False,
add_generation_prompt=True,
)
@@ -269,16 +312,23 @@ class NullKVCache(KVCache):
def make_kv_cache(
- model: Model,
- max_kv_size: int | None = None,
-) -> list[KVCache | RotatingKVCache]:
+ model: Model, max_kv_size: int | None = None, keep: int = 0
+) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
+
if max_kv_size is None:
- logger.info("Using default KV cache")
- return [KVCache() for _ in model.layers]
+ if KV_CACHE_BITS is None:
+ logger.info("Using default KV cache")
+ return [KVCache() for _ in model.layers]
+ else:
+ logger.info("Using quantized KV cache")
+ return [
+ QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
+ for _ in model.layers
+ ]
else:
- logger.info(f"Using rotating KV cache with {max_kv_size=}")
- return [RotatingKVCache(max_size=max_kv_size) for _ in model.layers]
+ logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
+ return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
def mlx_force_oom(size: int = 40000) -> None:
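The reworked `make_kv_cache` now picks between three cache types depending on `max_kv_size` and `KV_CACHE_BITS`. A usage sketch, assuming only that the model argument exposes a `.layers` sequence (`DummyModel` below is a hypothetical stand-in, not an exo class):

```python
# Sketch of which cache type make_kv_cache returns after this change.
from exo.engines.mlx.utils_mlx import make_kv_cache

class DummyModel:
    layers = [None] * 4  # pretend the model has 4 transformer blocks

default_caches = make_kv_cache(DummyModel())
# -> one KVCache per layer when KV_CACHE_BITS is None, QuantizedKVCache otherwise

rotating_caches = make_kv_cache(DummyModel(), max_kv_size=3200, keep=1600)
# -> one RotatingKVCache per layer, capped at 3200 tokens, keeping the first 1600

print(type(default_caches[0]).__name__, type(rotating_caches[0]).__name__)
```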
diff --git a/src/exo/main.py b/src/exo/main.py
index 110d44a6..b21434af 100644
--- a/src/exo/main.py
+++ b/src/exo/main.py
@@ -1,6 +1,6 @@
-import signal
import argparse
import multiprocessing as mp
+import signal
from dataclasses import dataclass
from typing import Self
@@ -8,14 +8,13 @@ import anyio
from anyio.abc import TaskGroup
from pydantic import PositiveInt
-from exo.shared.logging import logger
import exo.routing.topics as topics
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
-from exo.shared.logging import logger_cleanup, logger_setup
+from exo.shared.logging import logger, logger_cleanup, logger_setup
from exo.shared.types.commands import KillCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
@@ -119,6 +118,7 @@ class Node:
# if this is our second call to shutdown, just sys.exit
if self._tg.cancel_scope.cancel_called:
import sys
+
sys.exit(1)
self._tg.cancel_scope.cancel()
@@ -208,7 +208,7 @@ class Node:
def main():
args = Args.parse()
-
+
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
diff --git a/src/exo/master/api.py b/src/exo/master/api.py
index 22074064..69176792 100644
--- a/src/exo/master/api.py
+++ b/src/exo/master/api.py
@@ -14,6 +14,7 @@ from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
+from exo.engines.mlx.constants import HIDE_THINKING
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.models.model_cards import MODEL_CARDS
@@ -45,6 +46,7 @@ from exo.shared.types.models import ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId
+from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import OrderedBuffer
@@ -171,9 +173,9 @@ class API:
)
command = CreateInstance(
- command_id=CommandId(),
model_meta=model_meta,
instance_meta=payload.instance_meta,
+ min_nodes=payload.min_nodes,
sharding=payload.sharding,
)
await self._send(command)
@@ -194,7 +196,6 @@ class API:
raise HTTPException(status_code=404, detail="Instance not found")
command = DeleteInstance(
- command_id=CommandId(),
instance_id=instance_id,
)
await self._send(command)
@@ -212,17 +213,26 @@ class API:
self._chat_completion_queues[command_id] = asyncio.Queue()
finished = False
+ is_thinking = False
while not finished:
# TODO: how long should this timeout be?
chunk = await asyncio.wait_for(
self._chat_completion_queues[command_id].get(), timeout=600
)
assert isinstance(chunk, TokenChunk)
+ # TODO: Do we want this?
+ if HIDE_THINKING:
+                if chunk.text == "<think>":
+                    is_thinking = True
+                    chunk.text = "\n"
+                if chunk.text == "</think>":
+                    is_thinking = False
+                    chunk.text = "\n"
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
- yield f"data: {chunk_response.model_dump_json()}\n\n"
+
+ if not HIDE_THINKING or not is_thinking:
+ yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
@@ -244,31 +254,6 @@ class API:
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
- # Preprocess messages for GPT-OSS harmony format if needed
- # TODO: This is slop surely we get rid
- if "gpt-oss" in payload.model.lower():
- import re
-
- for message in payload.messages:
- if message.content and "<|channel|>" in message.content:
- # Parse harmony format tags
- thinking_pattern = r"<\|channel\|>(.*?)(?=<\|message\|>|$)"
- content_pattern = r"<\|message\|>(.*?)(?=<\|end\|>|$)"
-
- thinking_match = re.search(
- thinking_pattern, message.content, re.DOTALL
- )
- content_match = re.search(
- content_pattern, message.content, re.DOTALL
- )
-
- if content_match:
- # Extract the actual content
- message.content = content_match.group(1).strip()
- if thinking_match:
- # Store thinking in the thinking field
- message.thinking = thinking_match.group(1).strip()
-
if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
@@ -279,7 +264,6 @@ class API:
)
command = ChatCompletion(
- command_id=CommandId(),
request_params=payload,
)
await self._send(command)
@@ -325,9 +309,19 @@ class API:
tg.start_soon(uvicorn_server.serve)
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
+ tg.start_soon(self._print_banner_when_ready, uvicorn_server)
self.command_sender.close()
self.global_event_receiver.close()
+ async def _print_banner_when_ready(self, uvicorn_server: uvicorn.Server):
+ """Wait for the uvicorn server to be ready, then print the startup banner."""
+ # TODO: Is this the best condition to check for?
+ # The point is this should log when exo is ready.
+ while not uvicorn_server.started:
+ await asyncio.sleep(0.1)
+
+ print_startup_banner(self.port)
+
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
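The banner task added above simply polls `uvicorn_server.started`. The same pattern in isolation, with a stand-in flag object instead of a real `uvicorn.Server`:

```python
# Minimal sketch of the poll-until-started pattern used by
# _print_banner_when_ready; FakeServer is a stand-in for uvicorn.Server.
import asyncio


class FakeServer:
    started = False


async def announce_when_ready(server: FakeServer) -> None:
    while not server.started:
        await asyncio.sleep(0.1)
    print("dashboard ready")


async def main() -> None:
    server = FakeServer()
    waiter = asyncio.create_task(announce_when_ready(server))
    await asyncio.sleep(0.3)  # pretend startup work happens here
    server.started = True
    await waiter


asyncio.run(main())
```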
diff --git a/src/exo/master/main.py b/src/exo/master/main.py
index 7f481cb5..5dadb5c3 100644
--- a/src/exo/master/main.py
+++ b/src/exo/master/main.py
@@ -209,7 +209,7 @@ class Master:
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
- # TODO: SQL
+ # TODO: SQL <- What does this mean?
self._event_log.append(event)
await self._send_event(indexed)
diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py
index fb49666f..7f345660 100644
--- a/src/exo/master/placement.py
+++ b/src/exo/master/placement.py
@@ -3,6 +3,8 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence
+from loguru import logger
+
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
@@ -41,13 +43,13 @@ def get_instance_placements_after_create(
tb_only: bool = False,
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
- from loguru import logger
logger.info("finding cycles:")
cycles = topology.get_cycles()
- logger.info(f"{cycles=}")
singleton_cycles = [[node] for node in all_nodes]
- candidate_cycles = cycles + singleton_cycles
+ candidate_cycles = list(
+ filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
+ )
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
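The new `min_nodes` filter drops candidate cycles that are too small before the memory filtering runs. A toy illustration with made-up node lists:

```python
# Toy version of the min_nodes filter added above (node names are invented).
cycles = [["a", "b", "c"], ["b", "c"]]
singleton_cycles = [["a"], ["b"], ["c"]]
min_nodes = 2

candidate_cycles = list(
    filter(lambda it: len(it) >= min_nodes, cycles + singleton_cycles)
)
print(candidate_cycles)  # [['a', 'b', 'c'], ['b', 'c']]  (singletons are gone)
```

With `min_nodes=1` (the API default) behaviour is unchanged, since every singleton passes the filter.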
diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py
index 1a4e7011..c96a8d35 100644
--- a/src/exo/master/placement_utils.py
+++ b/src/exo/master/placement_utils.py
@@ -210,27 +210,19 @@ def get_mlx_ibv_devices_matrix(
if i == j:
continue
- # just for debugging for now...
- for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph):
- interface_name = _find_interface_name_for_ip(connection_ip, node_i)
- logger.info(
- f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
- )
-
- matrix[i][j] = "rdma_en3" # TODO: hack, for now it's always en3
- continue
-
- for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph):
- # Set the first valid rmda i -> j connection - if there are multiple, we set essentially randomly - this is fine, the connection doesn't appear to have to be bidirectional
- if (
- interface_name := _find_interface_name_for_ip(
- connection_ip,
- node_i,
- )
- ) is not None:
+ # Find the IP J uses to talk to I
+ for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
+ # This is a local IP on I, which is attached to an interface: find that interface
+ if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
matrix[i][j] = interface_name
+ logger.info(
+ f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
+ )
break
else:
+ logger.warning(
+ f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
+ )
raise ValueError(
"Current ibv backend requires all-to-all rdma connections"
)
@@ -246,8 +238,9 @@ def _find_connection_ip(
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.list_connections():
if (
- connection.local_node_id == node_j.node_id
- and connection.send_back_node_id == node_i.node_id
+ connection.local_node_id == node_i.node_id
+ and connection.send_back_node_id == node_j.node_id
+ # TODO: Check if we need this.
and connection.send_back_multiaddr is not None
):
yield connection.send_back_multiaddr.ip_address
@@ -260,13 +253,13 @@ def _find_interface_name_for_ip(
if node_info.node_profile is None:
return None
+ logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
- logger.info(
- f"Checking interface {interface.name} for IP {interface.ip_address} == {ip_address}: {interface.ip_address == ip_address}"
- )
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
+ logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address == ip_address:
+ logger.info("Found")
return f"rdma_{interface.name}"
return None
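The corrected matrix lookup reads in the opposite direction from before: to fill `matrix[i][j]`, it takes the IP that node j uses to reach node i and resolves it against node i's own interface table. A toy model of that lookup (all addresses and interface names below are invented):

```python
# Toy model of the corrected lookup: matrix[i][j] must name an RDMA interface
# that is local to node i. All data here is invented for illustration.
send_back_ips = {("b", "a"): "192.168.10.1"}        # the IP node b uses to reach node a
local_interfaces = {"a": {"192.168.10.1": "en3"}}   # node a's own interface table

i, j = "a", "b"
ip = send_back_ips[(j, i)]                      # direction: j -> i
interface = local_interfaces[i][ip]             # resolve on node i, not node j
print(f"matrix[{i}][{j}] = rdma_{interface}")   # rdma_en3
```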
diff --git a/src/exo/master/tests/conftest.py b/src/exo/master/tests/conftest.py
index 39aa2b31..9ebfa152 100644
--- a/src/exo/master/tests/conftest.py
+++ b/src/exo/master/tests/conftest.py
@@ -41,11 +41,15 @@ def create_node():
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
+ ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
+ nonlocal ip_counter
+ # assign unique ips
+ ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
@@ -53,7 +57,7 @@ def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
- address=f"/ip4/169.254.0.1/tcp/{send_back_port}"
+ address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py
index a8b33e8e..c52b0b33 100644
--- a/src/exo/master/tests/test_placement.py
+++ b/src/exo/master/tests/test_placement.py
@@ -1,6 +1,7 @@
from typing import Callable
import pytest
+from loguru import logger
from exo.master.placement import (
get_instance_placements_after_create,
@@ -356,10 +357,18 @@ def test_tensor_rdma_backend_connectivity_matrix(
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
+ conn_b_a = create_connection(node_id_b, node_id_a)
+ conn_c_b = create_connection(node_id_c, node_id_b)
+ conn_a_c = create_connection(node_id_a, node_id_c)
+
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
+ assert conn_b_a.send_back_multiaddr is not None
+ assert conn_c_b.send_back_multiaddr is not None
+ assert conn_a_c.send_back_multiaddr is not None
+
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
@@ -368,7 +377,12 @@ def test_tensor_rdma_backend_connectivity_matrix(
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
- ip_address=conn_a_b.send_back_multiaddr.ip_address,
+ ip_address=conn_c_a.send_back_multiaddr.ip_address,
+ type="rdma",
+ ),
+ NetworkInterfaceInfo(
+ name="en4",
+ ip_address=conn_b_a.send_back_multiaddr.ip_address,
type="rdma",
),
ethernet_interface,
@@ -381,9 +395,14 @@ def test_tensor_rdma_backend_connectivity_matrix(
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
+ NetworkInterfaceInfo(
+ name="en3",
+ ip_address=conn_c_b.send_back_multiaddr.ip_address,
+ type="rdma",
+ ),
NetworkInterfaceInfo(
name="en4",
- ip_address=conn_b_c.send_back_multiaddr.ip_address,
+ ip_address=conn_a_b.send_back_multiaddr.ip_address,
type="rdma",
),
ethernet_interface,
@@ -397,8 +416,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
- name="en5",
- ip_address=conn_c_a.send_back_multiaddr.ip_address,
+ name="en3",
+ ip_address=conn_a_c.send_back_multiaddr.ip_address,
+ type="rdma",
+ ),
+ NetworkInterfaceInfo(
+ name="en4",
+ ip_address=conn_b_c.send_back_multiaddr.ip_address,
type="rdma",
),
ethernet_interface,
@@ -412,6 +436,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
+ topology.add_connection(conn_b_a)
+ topology.add_connection(conn_c_b)
+ topology.add_connection(conn_a_c)
create_instance_command = CreateInstance(
command_id=CommandId(),
@@ -444,9 +471,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
- assert matrix[idx_a][idx_b] == "rdma_en3"
- assert matrix[idx_b][idx_c] == "rdma_en4"
- assert matrix[idx_c][idx_a] == "rdma_en5"
+ logger.info(matrix)
+
+ assert matrix[idx_a][idx_b] == "rdma_en4"
+ assert matrix[idx_b][idx_c] == "rdma_en3"
+ assert matrix[idx_c][idx_a] == "rdma_en3"
assert ":" in instance.mlx_ibv_coordinator
assert not instance.mlx_ibv_coordinator.startswith("169.254")
diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py
index 16cc6adb..6ea031a7 100644
--- a/src/exo/shared/apply.py
+++ b/src/exo/shared/apply.py
@@ -252,9 +252,5 @@ def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> Sta
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
- if not topology.contains_connection(event.edge) and topology.contains_connection(
- event.edge.reverse()
- ):
- topology.remove_connection(event.edge.reverse())
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})
diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py
index 12051b3b..6368a72d 100644
--- a/src/exo/shared/models/model_cards.py
+++ b/src/exo/shared/models/model_cards.py
@@ -14,32 +14,32 @@ class ModelCard(CamelCaseModel):
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
- "deepseek-v3-0324:4bit": ModelCard(
- short_id="deepseek-v3-0324:4bit",
- model_id="mlx-community/DeepSeek-V3-0324-4bit",
- name="DeepSeek V3 0324 (4-bit)",
- description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
- pretty_name="DeepSeek V3 0324 (4-bit)",
- storage_size=Memory.from_kb(409706307),
- n_layers=61,
- ),
- ),
- "deepseek-v3-0324": ModelCard(
- short_id="deepseek-v3-0324",
- model_id="mlx-community/DeepSeek-v3-0324-8bit",
- name="DeepSeek V3 0324 (8-bit)",
- description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
- pretty_name="DeepSeek V3 0324 (8-bit)",
- storage_size=Memory.from_kb(754706307),
- n_layers=61,
- ),
- ),
+ # "deepseek-v3-0324:4bit": ModelCard(
+ # short_id="deepseek-v3-0324:4bit",
+ # model_id="mlx-community/DeepSeek-V3-0324-4bit",
+ # name="DeepSeek V3 0324 (4-bit)",
+ # description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
+ # pretty_name="DeepSeek V3 0324 (4-bit)",
+ # storage_size=Memory.from_kb(409706307),
+ # n_layers=61,
+ # ),
+ # ),
+ # "deepseek-v3-0324": ModelCard(
+ # short_id="deepseek-v3-0324",
+ # model_id="mlx-community/DeepSeek-v3-0324-8bit",
+ # name="DeepSeek V3 0324 (8-bit)",
+ # description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
+ # pretty_name="DeepSeek V3 0324 (8-bit)",
+ # storage_size=Memory.from_kb(754706307),
+ # n_layers=61,
+ # ),
+ # ),
"deepseek-v3.1": ModelCard(
short_id="deepseek-v3.1",
model_id="mlx-community/DeepSeek-V3.1-8bit",
@@ -67,32 +67,32 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
),
# deepseek r1
- "deepseek-r1-0528:4bit": ModelCard(
- short_id="deepseek-r1-0528:4bit",
- model_id="mlx-community/DeepSeek-R1-0528-4bit",
- name="DeepSeek-R1-0528 (4-bit)",
- description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
- pretty_name="DeepSeek R1 671B (4-bit)",
- storage_size=Memory.from_kb(409706307),
- n_layers=61,
- ),
- ),
- "deepseek-r1-0528": ModelCard(
- short_id="deepseek-r1-0528",
- model_id="mlx-community/DeepSeek-R1-0528-8bit",
- name="DeepSeek-R1-0528 (8-bit)",
- description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
- pretty_name="DeepSeek R1 671B (8-bit)",
- storage_size=Memory.from_bytes(754998771712),
- n_layers=61,
- ),
- ),
+ # "deepseek-r1-0528:4bit": ModelCard(
+ # short_id="deepseek-r1-0528:4bit",
+ # model_id="mlx-community/DeepSeek-R1-0528-4bit",
+ # name="DeepSeek-R1-0528 (4-bit)",
+ # description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
+ # pretty_name="DeepSeek R1 671B (4-bit)",
+ # storage_size=Memory.from_kb(409706307),
+ # n_layers=61,
+ # ),
+ # ),
+ # "deepseek-r1-0528": ModelCard(
+ # short_id="deepseek-r1-0528",
+ # model_id="mlx-community/DeepSeek-R1-0528-8bit",
+ # name="DeepSeek-R1-0528 (8-bit)",
+ # description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
+ # pretty_name="DeepSeek R1 671B (8-bit)",
+ # storage_size=Memory.from_bytes(754998771712),
+ # n_layers=61,
+ # ),
+ # ),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
@@ -228,19 +228,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=32,
),
),
- "phi-3-mini:128k": ModelCard(
- short_id="phi-3-mini:128k",
- model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
- name="Phi 3 Mini 128k",
- description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
- pretty_name="Phi 3 Mini 128k",
- storage_size=Memory.from_kb(2099262),
- n_layers=32,
- ),
- ),
+ # "phi-3-mini:128k": ModelCard(
+ # short_id="phi-3-mini:128k",
+ # model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
+ # name="Phi 3 Mini 128k",
+ # description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
+ # pretty_name="Phi 3 Mini 128k",
+ # storage_size=Memory.from_kb(2099262),
+ # n_layers=32,
+ # ),
+ # ),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
@@ -268,19 +268,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
),
),
- "qwen3-235b-a22b": ModelCard(
- short_id="qwen3-235b-a22b",
- model_id="mlx-community/Qwen3-235B-A22B-4bit",
- name="Qwen3 235B, Active 22B (4-bit)",
- description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"),
- pretty_name="Qwen3 235B, Active 22B (4-bit)",
- storage_size=Memory.from_kb(123207680),
- n_layers=94,
- ),
- ),
+ # "qwen3-235b-a22b": ModelCard(
+ # short_id="qwen3-235b-a22b",
+ # model_id="mlx-community/Qwen3-235B-A22B-4bit",
+ # name="Qwen3 235B, Active 22B (4-bit)",
+ # description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"),
+ # pretty_name="Qwen3 235B, Active 22B (4-bit)",
+ # storage_size=Memory.from_kb(123207680),
+ # n_layers=94,
+ # ),
+ # ),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id="mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit",
@@ -308,31 +308,31 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=40,
),
),
- "granite-3.3-8b": ModelCard(
- short_id="granite-3.3-8b",
- model_id="mlx-community/granite-3.3-8b-instruct-fp16",
- name="Granite 3.3 8B",
- description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
- pretty_name="Granite 3.3 8B",
- storage_size=Memory.from_kb(15958720),
- n_layers=40,
- ),
- ),
+ # "granite-3.3-8b": ModelCard(
+ # short_id="granite-3.3-8b",
+ # model_id="mlx-community/granite-3.3-8b-instruct-fp16",
+ # name="Granite 3.3 8B",
+ # description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
+ # pretty_name="Granite 3.3 8B",
+ # storage_size=Memory.from_kb(15958720),
+ # n_layers=40,
+ # ),
+ # ),
# smol-lm
- "smol-lm-135m": ModelCard(
- short_id="smol-lm-135m",
- model_id="mlx-community/SmolLM-135M-4bit",
- name="Smol LM 135M",
- description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
- tags=[],
- metadata=ModelMetadata(
- model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
- pretty_name="Smol LM 135M",
- storage_size=Memory.from_kb(73940),
- n_layers=30,
- ),
- ),
+ # "smol-lm-135m": ModelCard(
+ # short_id="smol-lm-135m",
+ # model_id="mlx-community/SmolLM-135M-4bit",
+ # name="Smol LM 135M",
+ # description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
+ # tags=[],
+ # metadata=ModelMetadata(
+ # model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
+ # pretty_name="Smol LM 135M",
+ # storage_size=Memory.from_kb(73940),
+ # n_layers=30,
+ # ),
+ # ),
}
diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py
index c88e1f59..7413161f 100644
--- a/src/exo/shared/topology.py
+++ b/src/exo/shared/topology.py
@@ -70,6 +70,9 @@ class Topology:
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
+ if connection in self._edge_id_to_rx_id_map:
+ return
+
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
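The early return added above makes `add_connection` idempotent for an edge that is already tracked. The same guard in toy form:

```python
# Toy sketch of the duplicate-edge guard added to Topology.add_connection.
edges: dict[tuple[str, str, str], int] = {}

def add_connection(local: str, send_back: str, addr: str) -> None:
    key = (local, send_back, addr)
    if key in edges:  # mirrors the new early return
        return
    edges[key] = len(edges)

add_connection("node-a", "node-b", "/ip4/169.254.0.2/tcp/1235")
add_connection("node-a", "node-b", "/ip4/169.254.0.2/tcp/1235")  # no-op
print(len(edges))  # 1
```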
diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py
index 131fc7e2..3ec61289 100644
--- a/src/exo/shared/types/api.py
+++ b/src/exo/shared/types/api.py
@@ -24,13 +24,20 @@ class ModelListModel(BaseModel):
class ModelList(BaseModel):
- object: str = "list"
+ object: Literal["list"] = "list"
data: list[ModelListModel]
+class ChatCompletionMessageText(BaseModel):
+ type: Literal["text"] = "text"
+ text: str
+
+
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
- content: str | None = None
+ content: (
+ str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
+ ) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
tool_calls: list[dict[str, Any]] | None = None
@@ -55,20 +62,6 @@ class Logprobs(BaseModel):
content: list[LogprobsContentItem] | None = None
-class StreamingChoiceResponse(BaseModel):
- index: int
- delta: ChatCompletionMessage
- logprobs: Logprobs | None = None
- finish_reason: FinishReason | None = None
-
-
-class ChatCompletionChoice(BaseModel):
- index: int
- message: ChatCompletionMessage
- logprobs: Logprobs | None = None
- finish_reason: FinishReason | None = None
-
-
class PromptTokensDetails(BaseModel):
cached_tokens: int = 0
audio_tokens: int = 0
@@ -89,6 +82,21 @@ class Usage(BaseModel):
completion_tokens_details: CompletionTokensDetails | None = None
+class StreamingChoiceResponse(BaseModel):
+ index: int
+ delta: ChatCompletionMessage
+ logprobs: Logprobs | None = None
+ finish_reason: FinishReason | None = None
+ usage: Usage | None = None
+
+
+class ChatCompletionChoice(BaseModel):
+ index: int
+ message: ChatCompletionMessage
+ logprobs: Logprobs | None = None
+ finish_reason: FinishReason | None = None
+
+
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
@@ -125,8 +133,8 @@ class CreateInstanceTaskParams(BaseModel):
# TODO: in future the user could specify a specific Instance, not just a model_id
model_id: str
sharding: Sharding = Sharding.Pipeline
- # TODO: fix
instance_meta: InstanceMeta = InstanceMeta.MlxRing
+ min_nodes: int = 1
class DeleteInstanceTaskParams(BaseModel):
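With the widened `content` field, all of the shapes below validate (module path and field names as shown in this diff); `apply_chat_template` in utils_mlx.py then flattens the structured forms back to plain text. The last line also exercises the new `min_nodes` field on the launch payload:

```python
# Sketch: content shapes now accepted by ChatCompletionMessage, plus the new
# min_nodes field on CreateInstanceTaskParams.
from exo.shared.types.api import (
    ChatCompletionMessage,
    ChatCompletionMessageText,
    CreateInstanceTaskParams,
)

ChatCompletionMessage(role="user", content="plain string")
ChatCompletionMessage(role="user", content=ChatCompletionMessageText(text="one text part"))
ChatCompletionMessage(
    role="user",
    content=[ChatCompletionMessageText(text="a list with a single text part")],
)

CreateInstanceTaskParams(model_id="mlx-community/Kimi-K2-Thinking", min_nodes=2)
```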
diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py
index 9ea2aa3f..1deca8ff 100644
--- a/src/exo/shared/types/commands.py
+++ b/src/exo/shared/types/commands.py
@@ -29,6 +29,7 @@ class CreateInstance(BaseCommand):
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
+ min_nodes: int
class DeleteInstance(BaseCommand):
diff --git a/src/exo/shared/types/topology.py b/src/exo/shared/types/topology.py
index 1695a98b..33d7c752 100644
--- a/src/exo/shared/types/topology.py
+++ b/src/exo/shared/types/topology.py
@@ -12,20 +12,17 @@ class NodeInfo(CamelCaseModel):
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
- send_back_multiaddr: Multiaddr | None
+ send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
- if self.send_back_multiaddr:
- return hash(
- (
- self.local_node_id,
- self.send_back_node_id,
- self.send_back_multiaddr.address,
- )
+ return hash(
+ (
+ self.local_node_id,
+ self.send_back_node_id,
+ self.send_back_multiaddr.address,
)
- else:
- return hash((self.local_node_id, self.send_back_node_id))
+ )
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
@@ -37,13 +34,4 @@ class Connection(CamelCaseModel):
)
def is_thunderbolt(self) -> bool:
- return self.send_back_multiaddr is not None and str(
- self.send_back_multiaddr.ipv4_address
- ).startswith("169.254")
-
- def reverse(self) -> "Connection":
- return Connection(
- local_node_id=self.send_back_node_id,
- send_back_node_id=self.local_node_id,
- send_back_multiaddr=None,
- )
+ return str(self.send_back_multiaddr.ipv4_address).startswith("200.0")
diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py
index 9230001f..b68e60a4 100644
--- a/src/exo/shared/types/worker/instances.py
+++ b/src/exo/shared/types/worker/instances.py
@@ -41,6 +41,7 @@ class BoundInstance(CamelCaseModel):
instance: Instance
bound_runner_id: RunnerId
+ @property
def bound_shard(self) -> ShardMetadata:
shard = self.instance.shard(self.bound_runner_id)
assert shard is not None
diff --git a/src/exo/shared/types/worker/runners.py b/src/exo/shared/types/worker/runners.py
index da8544a3..5cceb83b 100644
--- a/src/exo/shared/types/worker/runners.py
+++ b/src/exo/shared/types/worker/runners.py
@@ -48,6 +48,7 @@ class RunnerRunning(BaseRunnerStatus):
class RunnerShutdown(BaseRunnerStatus):
pass
+
class RunnerFailed(BaseRunnerStatus):
error_message: str | None = None
diff --git a/src/exo/utils/banner.py b/src/exo/utils/banner.py
new file mode 100644
index 00000000..cae6eac3
--- /dev/null
+++ b/src/exo/utils/banner.py
@@ -0,0 +1,34 @@
+def print_startup_banner(port: int) -> None:
+ """Print a prominent startup banner with API endpoint information."""
+ banner = """
+╔═══════════════════════════════════════════════════════════════════════╗
+║ ║
+║ ███████╗██╗ ██╗ ██████╗ ║
+║ ██╔════╝╚██╗██╔╝██╔═══██╗ ║
+║ █████╗ ╚███╔╝ ██║ ██║ ║
+║ ██╔══╝ ██╔██╗ ██║ ██║ ║
+║ ███████╗██╔╝ ██╗╚██████╔╝ ║
+║ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ║
+║ ║
+║ Distributed AI Inference Cluster ║
+║ ║
+╚═══════════════════════════════════════════════════════════════════════╝
+"""
+
+ dashboard_url = f"http://localhost:{port}"
+
+ api_info = f"""
+╔═══════════════════════════════════════════════════════════════════════╗
+║ ║
+║ 🌐 Dashboard & API Ready ║
+║ ║
+║ {dashboard_url}{" " * (69 - len(dashboard_url))}║
+║ ║
+║ Click the URL above to open the dashboard in your browser ║
+║ ║
+╚═══════════════════════════════════════════════════════════════════════╝
+"""
+
+ print(banner)
+ print(api_info)
+ print()
diff --git a/src/exo/utils/channels.py b/src/exo/utils/channels.py
index 70971cf3..c335fb02 100644
--- a/src/exo/utils/channels.py
+++ b/src/exo/utils/channels.py
@@ -139,7 +139,9 @@ class MpSender[T]:
# == unique to Mp channels ==
def join(self) -> None:
"""Ensure any queued messages are resolved before continuing"""
- assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
+ assert self._state.closed.is_set(), (
+ "Mp channels must be closed before being joined"
+ )
self._state.buffer.join_thread()
# == context manager support ==
@@ -209,7 +211,9 @@ class MpReceiver[T]:
# == unique to Mp channels ==
def join(self) -> None:
"""Block until all enqueued messages are drained off our side of the buffer"""
- assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
+ assert self._state.closed.is_set(), (
+ "Mp channels must be closed before being joined"
+ )
self._state.buffer.join_thread()
# == iterator support ==
diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py
index dfdda537..e44b1975 100644
--- a/src/exo/worker/plan.py
+++ b/src/exo/worker/plan.py
@@ -100,12 +100,12 @@ def _model_needs_download(
for runner in runners.values():
if (
isinstance(runner.status, RunnerWaitingForModel)
- and runner.bound_instance.bound_shard() not in download_status
+ and runner.bound_instance.bound_shard not in download_status
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
- shard_metadata=runner.bound_instance.bound_shard(),
+ shard_metadata=runner.bound_instance.bound_shard,
)
@@ -160,7 +160,7 @@ def _ready_to_warmup(
)
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
)
- and runner.bound_instance.bound_shard().device_rank != 0
+ and runner.bound_instance.bound_shard.device_rank != 0
)
or (
all(
@@ -170,7 +170,7 @@ def _ready_to_warmup(
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
if global_runner_id != runner.bound_instance.bound_runner_id
)
- and runner.bound_instance.bound_shard().device_rank == 0
+ and runner.bound_instance.bound_shard.device_rank == 0
)
):
return StartWarmup(instance_id=runner.bound_instance.instance.instance_id)
diff --git a/src/exo/worker/runner/generate.py b/src/exo/worker/runner/generate.py
index 09d51b6c..134ac956 100644
--- a/src/exo/worker/runner/generate.py
+++ b/src/exo/worker/runner/generate.py
@@ -6,6 +6,9 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.engines.mlx import Model
+
+# from exo.engines.mlx.cache import KVPrefixCache
+from exo.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
@@ -70,6 +73,8 @@ def warmup_inference(
sampler=sampler,
prompt_cache=cache,
prefill_step_size=65536,
+ kv_group_size=KV_GROUP_SIZE,
+ kv_bits=KV_BITS,
):
logger.info("Generated warmup token: " + str(_r.text))
tokens_generated += 1
@@ -94,19 +99,19 @@ def mlx_generate(
chat_task_data=task,
)
- cache = make_kv_cache(
- model=model,
- )
+ caches = make_kv_cache(model=model)
- max_tokens = task.max_tokens or 1000
+ max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
- prompt_cache=cache,
+ prompt_cache=caches,
prefill_step_size=65536,
+ kv_group_size=KV_GROUP_SIZE,
+ kv_bits=KV_BITS,
):
logger.info(out.text)
if out.finish_reason is not None and out.finish_reason not in get_args(
diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py
index f2b23e35..87eb742d 100644
--- a/src/exo/worker/runner/runner.py
+++ b/src/exo/worker/runner/runner.py
@@ -22,7 +22,7 @@ from exo.shared.types.tasks import (
)
from exo.shared.types.worker.commands_runner import (
GenerationResponse,
- TokenizedResponse,
+ # TokenizedResponse,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
@@ -31,12 +31,12 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
+ RunnerShutdown,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
- RunnerShutdown
)
-from exo.utils.channels import MpReceiver, MpSender, ClosedResourceError
+from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.generate import mlx_generate, warmup_inference
@@ -49,7 +49,7 @@ def main(
instance, runner_id, shard_metadata = (
bound_instance.instance,
bound_instance.bound_runner_id,
- bound_instance.bound_shard(),
+ bound_instance.bound_shard,
)
try:
logger.info("hello from the runner")
@@ -115,6 +115,7 @@ def main(
model=model,
tokenizer=tokenizer,
sampler=sampler,
+ # kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -185,9 +186,9 @@ def main(
),
)
)
- case TokenizedResponse():
- # TODO: something here ig
- logger.info("Finished tokenizing?")
+ # case TokenizedResponse():
+                        # TODO: something here ig
+ # logger.info("Finished tokenizing?")
current_status = RunnerReady()
logger.info("runner ready")
@@ -212,9 +213,7 @@ def main(
)
)
event_sender.send(
- RunnerStatusUpdated(
- runner_id=runner_id, runner_status=RunnerShutdown()
- )
+ RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py
index 768cefa8..cda356ae 100644
--- a/src/exo/worker/runner/runner_supervisor.py
+++ b/src/exo/worker/runner/runner_supervisor.py
@@ -67,7 +67,7 @@ class RunnerSupervisor:
daemon=True,
)
- shard_metadata = bound_instance.bound_shard()
+ shard_metadata = bound_instance.bound_shard
self = cls(
bound_instance=bound_instance,
@@ -109,12 +109,13 @@ class RunnerSupervisor:
if not self.runner_process.is_alive():
return
- logger.critical("Runner process didn't respond to SIGKILL. System resources may have leaked")
+ logger.critical(
+ "Runner process didn't respond to SIGKILL. System resources may have leaked"
+ )
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
-
async def start_task(self, task: Task):
event = anyio.Event()
@@ -126,7 +127,6 @@ class RunnerSupervisor:
return
await event.wait()
-
async def _forward_events(self):
with self._ev_recv as events:
try:
@@ -140,7 +140,6 @@ class RunnerSupervisor:
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
-
def __del__(self) -> None:
if self.runner_process.is_alive():
logger.warning("RunnerSupervisor was not stopped cleanly.")
@@ -152,7 +151,7 @@ class RunnerSupervisor:
await to_thread.run_sync(self.runner_process.join, 1)
rc = self.runner_process.exitcode
if rc == 0:
- #
+ #
return
if isinstance(rc, int) and rc < 0:
diff --git a/src/exo/worker/tests/test_plan/test_worker_plan.py b/src/exo/worker/tests/test_plan/test_worker_plan.py
index 02f9612d..c555edd4 100644
--- a/src/exo/worker/tests/test_plan/test_worker_plan.py
+++ b/src/exo/worker/tests/test_plan/test_worker_plan.py
@@ -1,4 +1,5 @@
import pytest
+from exo.worker.common import AssignedRunner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.state import State
@@ -27,7 +28,6 @@ from exo.shared.types.worker.runners import (
RunningRunnerStatus,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
-from exo.worker.common import AssignedRunner
from exo.worker.main import Worker
from exo.worker.plan import plan
from exo.worker.tests.constants import (
diff --git a/tmp/run_llm.sh b/tmp/run_llm.sh
index 07599c2d..b9dbb61b 100755
--- a/tmp/run_llm.sh
+++ b/tmp/run_llm.sh
@@ -13,9 +13,9 @@ QUERY="$*"
curl -sN -X POST "http://$HOST:8000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{
- \"model\": \"mlx-community/Llama-3.3-70B-Instruct-8bit\",
+ \"model\": \"mlx-community/Kimi-K2-Thinking\",
\"stream\": true,
- \"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\" }]
+ \"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\"}]
}" |
grep --line-buffered '^data:' |
grep --line-buffered -v 'data: \[DONE\]' |