diff --git a/.github/configs/bench_simple.yaml b/.github/configs/bench_simple.yaml index 18f7042b..9a76b6db 100644 --- a/.github/configs/bench_simple.yaml +++ b/.github/configs/bench_simple.yaml @@ -26,7 +26,7 @@ model_ids: # - "mlx-community/Llama-3.2-1B-Instruct-4bit" # Sharding strategy: "Pipeline" or "Tensor" -sharding: "Pipeline" +sharding: "Tensor" # Instance type: "MlxRing" or "MlxIbv" instance_meta: "MlxIbv" @@ -48,6 +48,11 @@ stages: # generation_length: 64 # time_between_requests: 2.0 # iterations: 5 + # - name: "pp64_g64" + # prompt_length: 64 + # generation_length: 64 + # time_between_requests: 2.0 + # iterations: 5 # - name: "pp64_g512" # prompt_length: 64 # generation_length: 512 @@ -58,6 +63,11 @@ stages: # generation_length: 64 # time_between_requests: 2.0 # iterations: 5 + - name: "pp256_g64" + prompt_length: 256 + generation_length: 64 + time_between_requests: 2.0 + iterations: 5 # - name: "pp256_g512" # prompt_length: 256 # generation_length: 512 @@ -83,26 +93,26 @@ stages: # generation_length: 512 # time_between_requests: 2.0 # iterations: 10 - - name: "pp4096_g64" - prompt_length: 4096 - generation_length: 64 - time_between_requests: 2.0 - iterations: 4 + # - name: "pp4096_g64" + # prompt_length: 4096 + # generation_length: 64 + # time_between_requests: 2.0 + # iterations: 4 # - name: "pp4096_g512" # prompt_length: 4096 # generation_length: 512 # time_between_requests: 2.0 # iterations: 10 - - name: "pp8192_g64" - prompt_length: 8192 - generation_length: 64 - time_between_requests: 2.0 - iterations: 4 + # - name: "pp8192_g64" + # prompt_length: 8192 + # generation_length: 64 + # time_between_requests: 2.0 + # iterations: 5 # - name: "pp8192_g512" # prompt_length: 8192 # generation_length: 512 # time_between_requests: 2.0 - # iterations: 10 + # iterations: 5 # - name: "pp16384_g64" # prompt_length: 16384 # generation_length: 64 diff --git a/.mlx_typings/mlx_lm/models/cache.pyi b/.mlx_typings/mlx_lm/models/cache.pyi index 30fe1b85..177dde3a 100644 --- a/.mlx_typings/mlx_lm/models/cache.pyi +++ b/.mlx_typings/mlx_lm/models/cache.pyi @@ -36,9 +36,7 @@ def save_prompt_cache( state. """ -def load_prompt_cache( - file_name, return_metadata=... -): # -> tuple[list[Any], Any] | list[Any]: +def load_prompt_cache(file_name: str, return_metadata=...) -> array: """ Load a prompt cache from a file. diff --git a/dashboard/index.html b/dashboard/index.html index 62ec32f5..d0ddc6fc 100644 --- a/dashboard/index.html +++ b/dashboard/index.html @@ -31,10 +31,10 @@ max-width: 1200px; margin-bottom: 15px; margin-top: 20px; - text-align: left; - display: flex; - justify-content: space-between; + display: grid; + grid-template-columns: 1fr auto 1fr; align-items: flex-end; + gap: 20px; } .dashboard-header h1 { @@ -67,6 +67,18 @@ flex-direction: column; } + .header-center { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + } + + .header-right { + display: flex; + justify-content: flex-end; + } + .header-instances-button { background-color: transparent; border: 1px solid var(--exo-medium-gray); @@ -972,11 +984,11 @@
[markup body lost in extraction]
@@ -986,16 +998,26 @@
[markup body lost in extraction]
@@ -1004,10 +1026,15 @@
[markup body lost in extraction; surviving text fragments: "EXO logo", "Fetching data...". Judging from the CSS and script changes elsewhere in this patch, these dashboard/index.html hunks restructure the header into left/center/right groups and add the minNodesOptions radio container used by the new min_nodes launch option.]
@@ -1068,12 +1095,14 @@ const modelSelect = document.getElementById('modelSelect'); const launchInstanceButton = document.getElementById('launchInstanceButton'); const launchStatus = document.getElementById('launchStatus'); + const minNodesOptions = document.getElementById('minNodesOptions'); const USE_MOCK_DATA = false; // <<< FLAG TO TOGGLE MOCK DATA let currentlySelectedNodeId = null; // To store the ID of the currently selected node let nodeIdToFriendlyName = {}; // Map nodeId -> friendly name for download sections let instanceIdToColor = {}; // Map instanceId -> color for visual coding let connectionToInstances = {}; // Map "nodeA|nodeB" -> [instanceIds] using that connection + let currentNodeCount = 1; // Track the current number of nodes in topology const API_ENDPOINT = window.location.origin + window.location.pathname.replace(/\/$/, "") + '/state'; const REFRESH_INTERVAL = 1000; // 1 second @@ -1218,6 +1247,16 @@ instancesMenuButton.classList.toggle('active', sidebarOpen); } + // Edge IP display flag (can be toggled from console) + window.exoShowEdgeIPs = false; + + // Helper function to toggle IP display (accessible from console) + window.toggleEdgeIPs = function() { + window.exoShowEdgeIPs = !window.exoShowEdgeIPs; + console.log(`Edge IP display ${window.exoShowEdgeIPs ? 'enabled' : 'disabled'}`); + return window.exoShowEdgeIPs; + }; + // Fetch available models and populate dropdown async function fetchAndPopulateModels() { try { @@ -1281,8 +1320,11 @@ const selectedSharding = document.querySelector('input[name="sharding"]:checked').value; const selectedInstanceMeta = document.querySelector('input[name="instance_meta"]:checked').value; + const minNodesRadio = document.querySelector('input[name="min_nodes"]:checked'); + const minNodes = minNodesRadio ? parseInt(minNodesRadio.value, 10) : 1; console.log("selectedSharding", selectedSharding); console.log("selectedInstanceMeta", selectedInstanceMeta); + console.log("minNodes", minNodes); try { showLaunchStatus('Launching instance...', 'loading'); @@ -1296,7 +1338,8 @@ body: JSON.stringify({ model_id: selectedModelId, sharding: selectedSharding, - instance_meta: selectedInstanceMeta + instance_meta: selectedInstanceMeta, + min_nodes: minNodes }) }); @@ -1858,6 +1901,39 @@ const edgesData = (topologyData && Array.isArray(topologyData.edges)) ? topologyData.edges : []; const nodeIds = Object.keys(nodesData); + // Update min nodes radio buttons based on current topology + currentNodeCount = Math.max(1, nodeIds.length); + if (minNodesOptions) { + // Get currently selected value before regenerating + const currentlySelected = document.querySelector('input[name="min_nodes"]:checked'); + const hasOnlyDefaultOption = minNodesOptions.children.length === 1; + // Default to maximum nodes on initial load, otherwise preserve user selection + const selectedValue = (currentlySelected && !hasOnlyDefaultOption) ? 
parseInt(currentlySelected.value, 10) : currentNodeCount; + + // Clear and regenerate radio buttons + minNodesOptions.innerHTML = ''; + for (let i = 1; i <= currentNodeCount; i++) { + const optionDiv = document.createElement('div'); + optionDiv.className = 'strategy-option'; + + const radio = document.createElement('input'); + radio.type = 'radio'; + radio.id = `minNodes${i}`; + radio.name = 'min_nodes'; + radio.value = i.toString(); + // Check if this should be selected (preserve selection or default to maximum) + radio.checked = (i === Math.min(selectedValue, currentNodeCount)); + + const label = document.createElement('label'); + label.htmlFor = `minNodes${i}`; + label.textContent = i.toString(); + + optionDiv.appendChild(radio); + optionDiv.appendChild(label); + minNodesOptions.appendChild(optionDiv); + } + } + if (nodeIds.length === 0) { const textEl = document.createElementNS('http://www.w3.org/2000/svg', 'text'); textEl.setAttribute('x', '50%'); @@ -2002,7 +2078,7 @@ arrowsGroup.appendChild(arrowSeg); // Add label for A->B direction (show all connections) - if (entry.aToBEdges && entry.aToBEdges.length > 0) { + if (window.exoShowEdgeIPs && entry.aToBEdges && entry.aToBEdges.length > 0) { // Count occurrences of each IP/interface combination const connectionCounts = new Map(); @@ -2067,7 +2143,7 @@ arrowsGroup.appendChild(arrowSeg); // Add label for B->A direction (show all connections) - if (entry.bToAEdges && entry.bToAEdges.length > 0) { + if (window.exoShowEdgeIPs && entry.bToAEdges && entry.bToAEdges.length > 0) { // Count occurrences of each IP/interface combination const connectionCounts = new Map(); diff --git a/src/exo/engines/mlx/auto_parallel.py b/src/exo/engines/mlx/auto_parallel.py index 345454db..4ff747b8 100644 --- a/src/exo/engines/mlx/auto_parallel.py +++ b/src/exo/engines/mlx/auto_parallel.py @@ -3,7 +3,10 @@ from functools import partial from inspect import signature from typing import TYPE_CHECKING, Callable, Protocol, cast, override -from mlx_lm.models.cache import KVCache, RotatingKVCache +from mlx_lm.models.cache import ( + KVCache, + _BaseCache, # pyright: ignore[reportPrivateUsage] +) from mlx_lm.models.deepseek_v3 import DeepseekV3MLP from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model from mlx_lm.models.llama import Model as LlamaModel @@ -91,7 +94,7 @@ class PipelineLastLayer(CustomMlxLayer): x, *args, **kwargs ).arguments.get("cache", None) - assert cache is None or isinstance(cache, (KVCache, RotatingKVCache)) + assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore output: mx.array = self.original_layer(x, *args, **kwargs) @@ -99,11 +102,7 @@ class PipelineLastLayer(CustomMlxLayer): output = mx.distributed.send( output, (self.r + 1) % self.s, group=self.group ) - if ( - cache is not None - and hasattr(cache, "keys") - and getattr(cache, "keys", None) is not None - ): + if cache is not None: # This change happened upstream - check out mlx github somewhere?? 
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType] diff --git a/src/exo/engines/mlx/cache.py b/src/exo/engines/mlx/cache.py new file mode 100644 index 00000000..f4e7df8d --- /dev/null +++ b/src/exo/engines/mlx/cache.py @@ -0,0 +1,102 @@ +from copy import deepcopy +from typing import Callable + +from mlx_lm import stream_generate +from mlx_lm.models.cache import _BaseCache, trim_prompt_cache +from mlx_lm.tokenizer_utils import TokenizerWrapper + +import mlx.core as mx +from exo.engines.mlx import Model +from exo.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE +from exo.engines.mlx.utils_mlx import make_kv_cache + + +class KVPrefixCache: + def __init__(self): + # Only one prefix cache per runner. + self.prompts: list[mx.array] = [] # mx array of tokens (ints) + self.caches: list[list[_BaseCache]] = [] + + def add_kv_cache( + self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache] + ): + tokenized_prompt = self.encode_prompt(tokenizer, prompt) + self.prompts.append(tokenized_prompt) + self.caches.append(deepcopy(cache)) + + def get_kv_cache( + self, + model: Model, + tokenizer: TokenizerWrapper, + sampler: Callable[[mx.array], mx.array], + prompt: str, + ) -> list[_BaseCache]: + tokenized_prompt = self.encode_prompt(tokenizer, prompt) + max_length = len(tokenized_prompt) + + best_snapshot_index, best_snapshot_length = None, 0 + + for i, cached_prompt in enumerate(self.prompts): + length = _get_prefix_length(tokenized_prompt, cached_prompt) + + if length == max_length: + return self.caches[i] + + if length > best_snapshot_length: + best_snapshot_index, best_snapshot_length = i, length + + if best_snapshot_index is not None: + prompt_cache = deepcopy(self.caches[best_snapshot_index]) + trim_prompt_cache(prompt_cache, max_length - best_snapshot_length) + tokenized_prompt = tokenized_prompt[best_snapshot_index:] + + else: + prompt_cache = make_kv_cache( + model, + # max_kv_size=MAX_KV_SIZE, + # keep=KEEP_KV_SIZE + ) + + prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache) + + return prompt_cache + + def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array: + add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + tokenized_prompt = tokenizer.encode( + prompt, add_special_tokens=add_special_tokens + ) + return mx.array(tokenized_prompt) + + +def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int: + n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE) + if n == 0: + return 0 + + equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32) + prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever + return int(mx.sum(prefix_mask).item()) + + +def prefill( + model: Model, + tokenizer: TokenizerWrapper, + sampler: Callable[[mx.array], mx.array], + prompt: mx.array, + cache: list[_BaseCache], +) -> None: + for _ in stream_generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=0, + sampler=sampler, + prompt_cache=cache, + prefill_step_size=2048, + kv_group_size=KV_GROUP_SIZE, + kv_bits=KV_BITS, + ): + pass diff --git a/src/exo/engines/mlx/constants.py b/src/exo/engines/mlx/constants.py new file mode 100644 index 00000000..c73d62d3 --- /dev/null +++ b/src/exo/engines/mlx/constants.py @@ -0,0 +1,17 @@ +# TODO: Do we want so many constants? 
+ +KV_GROUP_SIZE = 32 +KV_BITS = None +ATTENTION_KV_BITS = 4 +MAX_TOKENS = 8192 +MAX_KV_SIZE = 3200 +KEEP_KV_SIZE = 1600 +QUANTIZE_MODEL_MODE = "affine" +CACHE_GROUP_SIZE = 64 +KV_CACHE_BITS = 8 +TEMPERATURE = 1.0 + +# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True +TRUST_REMOTE_CODE = True +# TODO: Do we really want this? +HIDE_THINKING = False diff --git a/src/exo/engines/mlx/utils_mlx.py b/src/exo/engines/mlx/utils_mlx.py index 5f42ca9c..8c48bd2e 100644 --- a/src/exo/engines/mlx/utils_mlx.py +++ b/src/exo/engines/mlx/utils_mlx.py @@ -1,8 +1,10 @@ import os import resource +import time from typing import Any, Callable, cast -from mlx_lm.models.cache import KVCache, RotatingKVCache +from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache +from mlx_lm.models.deepseek_v3 import DeepseekV3Model from mlx_lm.sample_utils import make_sampler from mlx_lm.tokenizer_utils import TokenizerWrapper @@ -22,6 +24,14 @@ from exo.engines.mlx.auto_parallel import ( pipeline_auto_parallel, tensor_auto_parallel, ) +from exo.engines.mlx.constants import ( + CACHE_GROUP_SIZE, + KV_CACHE_BITS, + PATCH_SYSTEM_PROMPT, + TEMPERATURE, + TRUST_REMOTE_CODE, +) +from exo.shared.types.api import ChatCompletionMessageText from exo.shared.types.common import Host from exo.shared.types.memory import Memory from exo.shared.types.tasks import ChatCompletionTaskParams @@ -44,7 +54,6 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) mlx_rank: None | int = None mlx_world_size: None | int = None - def mx_barrier(group: mx.distributed.Group | None = None): mx.eval( mx.distributed.all_sum( @@ -87,7 +96,7 @@ def mlx_distributed_init( - mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup - strict: if True, raise an error if the distributed backend is not available """ - rank = bound_instance.bound_shard().device_rank + rank = bound_instance.bound_shard.device_rank logger.info(f"Starting initialization for rank {rank}") # TODO: singleton instances @@ -136,33 +145,40 @@ def initialize_mlx( """ mx.random.seed(42) - set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard())) + set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard)) - sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7) + sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE) logger.info("Created a sampler") if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1: logger.info(f"Single device used for {bound_instance.instance}") - model_path = build_model_path(bound_instance.bound_shard().model_meta.model_id) - model, _ = load_model(model_path, strict=True) - # TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True - tokenizer = cast( - TokenizerWrapper, - load_tokenizer( - model_path, - tokenizer_config_extra={"trust_remote_code": True}, - # TODO: HACK for Kimi K2 wrong eos token id - eos_token_ids=[163586] if "kimi-k2" in bound_instance.bound_shard().model_meta.model_id.lower() else None, - ), - ) - assert isinstance(tokenizer, TokenizerWrapper) + model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id) + start_time = time.perf_counter() + model, config = load_model(model_path, strict=True) + end_time = time.perf_counter() + logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s") + if isinstance(model.model, DeepseekV3Model): + pass + # model, config = quantize_model( + # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, 
quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE + # ) + + tokenizer = get_tokenizer(model_path, bound_instance.bound_shard) else: logger.info("Starting distributed init") group = mlx_distributed_init(bound_instance) - model, tokenizer = shard_and_load(bound_instance.bound_shard(), group=group) - set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard())) + start_time = time.perf_counter() + model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group) + end_time = time.perf_counter() + logger.info( + f"Time taken to shard and load model: {(end_time - start_time):.2f}s" + ) + + set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard)) + + logger.debug(model) return cast(Model, model), tokenizer, sampler @@ -174,20 +190,28 @@ def shard_and_load( model_path = build_model_path(shard_metadata.model_meta.model_id) model, config = load_model(model_path, lazy=True, strict=False) - logger.info(f"{config=}") + logger.debug(model) + if isinstance(model.model, DeepseekV3Model): + pass + # TODO: See if we should quantize the model. + # def is_attention_layer(path: str) -> bool: + # path = path.lower() + + # return "self_attn" in path and "layernorm" not in path + + + # def quant_predicate(path: str, module: nn.Module): + # if not isinstance(module, nn.Linear): + # return False + + # return is_attention_layer(path) + # model, config = quantize_model( + # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE + # ) + assert isinstance(model, nn.Module) - # TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True - tokenizer = cast( - TokenizerWrapper, - # TODO: HACK for Kimi K2 wrong eos token id - load_tokenizer( - model_path, - tokenizer_config_extra={"trust_remote_code": True}, - # TODO: HACK for Kimi K2 wrong eos token id - eos_token_ids=[163586] if "kimi-k2" in shard_metadata.model_meta.model_id.lower() else None, - ), - ) + tokenizer = get_tokenizer(model_path, shard_metadata) logger.info(f"Group size: {group.size()}, group rank: {group.rank()}") @@ -200,44 +224,63 @@ def shard_and_load( model = pipeline_auto_parallel(model, group, shard_metadata) mx.eval(model.parameters()) + + # TODO: Do we need this? 
mx.eval(model) + logger.debug("SHARDED") + logger.debug(model) + # Synchronize processes before generation to avoid timeout mx_barrier(group) return model, tokenizer +def get_tokenizer(model_path: str, shard_metadata: ShardMetadata): + tokenizer = cast( + TokenizerWrapper, + load_tokenizer( + model_path, + tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE}, + # TODO: HACK for Kimi K2 wrong eos token id + eos_token_ids=[163586] + if "kimi-k2" in shard_metadata.model_meta.model_id.lower() + else None, + ), + ) + assert isinstance(tokenizer, TokenizerWrapper) + + return tokenizer + + def apply_chat_template( tokenizer: TokenizerWrapper, chat_task_data: ChatCompletionTaskParams, ) -> str: # Now we can properly access the messages messages = chat_task_data.messages - messages_dicts: list[dict[str, Any]] = [msg.model_dump() for msg in messages] - # Filter out None values, keeping relevant keys for the model formatted_messages: list[dict[str, Any]] = [] - for message in messages_dicts: - filtered_message: dict[str, Any] = { - k: v - for k, v in message.items() # pyright: ignore[reportAny] - if v is not None - } + for i, message in enumerate(messages): + if isinstance(message.content, ChatCompletionMessageText): + message.content = message.content.text + if isinstance(message.content, list): + if len(message.content) != 1: + logger.warning("Received malformed prompt") + continue - # Verify we have required fields - if "role" not in filtered_message: - raise ValueError(f"Message missing 'role' field: {filtered_message}") - if "content" not in filtered_message and "thinking" not in filtered_message: - # If neither content nor thinking is present, skip this message + message.content = message.content[0].text + if message.content is None and message.thinking is None: continue - formatted_messages.append(filtered_message) - - messages_dicts = formatted_messages + # Null values are not valid when applying templates in tokenizer + formatted_messages.append( + {k: v for k, v in message.model_dump().items() if v is not None} + ) prompt: str = tokenizer.apply_chat_template( # type: ignore - messages_dicts, + formatted_messages, tokenize=False, add_generation_prompt=True, ) @@ -269,16 +312,23 @@ class NullKVCache(KVCache): def make_kv_cache( - model: Model, - max_kv_size: int | None = None, -) -> list[KVCache | RotatingKVCache]: + model: Model, max_kv_size: int | None = None, keep: int = 0 +) -> list[KVCache | RotatingKVCache | QuantizedKVCache]: assert hasattr(model, "layers") + if max_kv_size is None: - logger.info("Using default KV cache") - return [KVCache() for _ in model.layers] + if KV_CACHE_BITS is None: + logger.info("Using default KV cache") + return [KVCache() for _ in model.layers] + else: + logger.info("Using quantized KV cache") + return [ + QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS) + for _ in model.layers + ] else: - logger.info(f"Using rotating KV cache with {max_kv_size=}") - return [RotatingKVCache(max_size=max_kv_size) for _ in model.layers] + logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}") + return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers] def mlx_force_oom(size: int = 40000) -> None: diff --git a/src/exo/main.py b/src/exo/main.py index 110d44a6..b21434af 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -1,6 +1,6 @@ -import signal import argparse import multiprocessing as mp +import signal from dataclasses import dataclass from typing import Self @@ -8,14 +8,13 @@ import anyio from anyio.abc 
import TaskGroup from pydantic import PositiveInt -from exo.shared.logging import logger import exo.routing.topics as topics from exo.master.api import API # TODO: should API be in master? from exo.master.main import Master from exo.routing.router import Router, get_node_id_keypair from exo.shared.constants import EXO_LOG from exo.shared.election import Election, ElectionResult -from exo.shared.logging import logger_cleanup, logger_setup +from exo.shared.logging import logger, logger_cleanup, logger_setup from exo.shared.types.commands import KillCommand from exo.shared.types.common import NodeId, SessionId from exo.utils.channels import Receiver, channel @@ -119,6 +118,7 @@ class Node: # if this is our second call to shutdown, just sys.exit if self._tg.cancel_scope.cancel_called: import sys + sys.exit(1) self._tg.cancel_scope.cancel() @@ -208,7 +208,7 @@ class Node: def main(): args = Args.parse() - + mp.set_start_method("spawn") # TODO: Refactor the current verbosity system logger_setup(EXO_LOG, args.verbosity) diff --git a/src/exo/master/api.py b/src/exo/master/api.py index 22074064..69176792 100644 --- a/src/exo/master/api.py +++ b/src/exo/master/api.py @@ -14,6 +14,7 @@ from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles from loguru import logger +from exo.engines.mlx.constants import HIDE_THINKING from exo.shared.apply import apply from exo.shared.election import ElectionMessage from exo.shared.models.model_cards import MODEL_CARDS @@ -45,6 +46,7 @@ from exo.shared.types.models import ModelMetadata from exo.shared.types.state import State from exo.shared.types.tasks import ChatCompletionTaskParams from exo.shared.types.worker.instances import Instance, InstanceId +from exo.utils.banner import print_startup_banner from exo.utils.channels import Receiver, Sender from exo.utils.event_buffer import OrderedBuffer @@ -171,9 +173,9 @@ class API: ) command = CreateInstance( - command_id=CommandId(), model_meta=model_meta, instance_meta=payload.instance_meta, + min_nodes=payload.min_nodes, sharding=payload.sharding, ) await self._send(command) @@ -194,7 +196,6 @@ class API: raise HTTPException(status_code=404, detail="Instance not found") command = DeleteInstance( - command_id=CommandId(), instance_id=instance_id, ) await self._send(command) @@ -212,17 +213,26 @@ class API: self._chat_completion_queues[command_id] = asyncio.Queue() finished = False + is_thinking = False while not finished: # TODO: how long should this timeout be? chunk = await asyncio.wait_for( self._chat_completion_queues[command_id].get(), timeout=600 ) assert isinstance(chunk, TokenChunk) + # TODO: Do we want this? 
+ if HIDE_THINKING: + if chunk.text == "": + chunk.text = "\n" + if chunk.text == "": + chunk.text = "\n" chunk_response: ChatCompletionResponse = chunk_to_response( chunk, command_id ) logger.debug(f"chunk_response: {chunk_response}") - yield f"data: {chunk_response.model_dump_json()}\n\n" + + if not HIDE_THINKING or not is_thinking: + yield f"data: {chunk_response.model_dump_json()}\n\n" if chunk.finish_reason is not None: yield "data: [DONE]\n\n" @@ -244,31 +254,6 @@ class API: model_meta = await resolve_model_meta(payload.model) payload.model = model_meta.model_id - # Preprocess messages for GPT-OSS harmony format if needed - # TODO: This is slop surely we get rid - if "gpt-oss" in payload.model.lower(): - import re - - for message in payload.messages: - if message.content and "<|channel|>" in message.content: - # Parse harmony format tags - thinking_pattern = r"<\|channel\|>(.*?)(?=<\|message\|>|$)" - content_pattern = r"<\|message\|>(.*?)(?=<\|end\|>|$)" - - thinking_match = re.search( - thinking_pattern, message.content, re.DOTALL - ) - content_match = re.search( - content_pattern, message.content, re.DOTALL - ) - - if content_match: - # Extract the actual content - message.content = content_match.group(1).strip() - if thinking_match: - # Store thinking in the thinking field - message.thinking = thinking_match.group(1).strip() - if not any( instance.shard_assignments.model_id == payload.model for instance in self.state.instances.values() @@ -279,7 +264,6 @@ class API: ) command = ChatCompletion( - command_id=CommandId(), request_params=payload, ) await self._send(command) @@ -325,9 +309,19 @@ class API: tg.start_soon(uvicorn_server.serve) tg.start_soon(self._apply_state) tg.start_soon(self._pause_on_new_election) + tg.start_soon(self._print_banner_when_ready, uvicorn_server) self.command_sender.close() self.global_event_receiver.close() + async def _print_banner_when_ready(self, uvicorn_server: uvicorn.Server): + """Wait for the uvicorn server to be ready, then print the startup banner.""" + # TODO: Is this the best condition to check for? + # The point is this should log when exo is ready. + while not uvicorn_server.started: + await asyncio.sleep(0.1) + + print_startup_banner(self.port) + async def _apply_state(self): with self.global_event_receiver as events: async for f_event in events: diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 7f481cb5..5dadb5c3 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -209,7 +209,7 @@ class Master: event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage] - # TODO: SQL + # TODO: SQL <- What does this mean? 
self._event_log.append(event) await self._send_event(indexed) diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index fb49666f..7f345660 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -3,6 +3,8 @@ from collections.abc import Mapping from copy import deepcopy from typing import Sequence +from loguru import logger + from exo.master.placement_utils import ( filter_cycles_by_memory, get_hosts_from_subgraph, @@ -41,13 +43,13 @@ def get_instance_placements_after_create( tb_only: bool = False, ) -> dict[InstanceId, Instance]: all_nodes = list(topology.list_nodes()) - from loguru import logger logger.info("finding cycles:") cycles = topology.get_cycles() - logger.info(f"{cycles=}") singleton_cycles = [[node] for node in all_nodes] - candidate_cycles = cycles + singleton_cycles + candidate_cycles = list( + filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles) + ) cycles_with_sufficient_memory = filter_cycles_by_memory( candidate_cycles, command.model_meta.storage_size ) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 1a4e7011..c96a8d35 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -210,27 +210,19 @@ def get_mlx_ibv_devices_matrix( if i == j: continue - # just for debugging for now... - for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph): - interface_name = _find_interface_name_for_ip(connection_ip, node_i) - logger.info( - f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}" - ) - - matrix[i][j] = "rdma_en3" # TODO: hack, for now it's always en3 - continue - - for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph): - # Set the first valid rmda i -> j connection - if there are multiple, we set essentially randomly - this is fine, the connection doesn't appear to have to be bidirectional - if ( - interface_name := _find_interface_name_for_ip( - connection_ip, - node_i, - ) - ) is not None: + # Find the IP J uses to talk to I + for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph): + # This is a local IP on I, which is attached to an interface: find that interface + if interface_name := _find_interface_name_for_ip(connection_ip, node_i): matrix[i][j] = interface_name + logger.info( + f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}" + ) break else: + logger.warning( + f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}" + ) raise ValueError( "Current ibv backend requires all-to-all rdma connections" ) @@ -246,8 +238,9 @@ def _find_connection_ip( """Find all IP addresses that connect node i to node j.""" for connection in cycle_digraph.list_connections(): if ( - connection.local_node_id == node_j.node_id - and connection.send_back_node_id == node_i.node_id + connection.local_node_id == node_i.node_id + and connection.send_back_node_id == node_j.node_id + # TODO: Check if we need this. 
and connection.send_back_multiaddr is not None ): yield connection.send_back_multiaddr.ip_address @@ -260,13 +253,13 @@ def _find_interface_name_for_ip( if node_info.node_profile is None: return None + logger.info(f"Searching {node_info.node_id} for ip {ip_address}:") for interface in node_info.node_profile.network_interfaces: - logger.info( - f"Checking interface {interface.name} for IP {interface.ip_address} == {ip_address}: {interface.ip_address == ip_address}" - ) if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]: continue + logger.info(f" | {interface.name}: {interface.ip_address}") if interface.ip_address == ip_address: + logger.info("Found") return f"rdma_{interface.name}" return None diff --git a/src/exo/master/tests/conftest.py b/src/exo/master/tests/conftest.py index 39aa2b31..9ebfa152 100644 --- a/src/exo/master/tests/conftest.py +++ b/src/exo/master/tests/conftest.py @@ -41,11 +41,15 @@ def create_node(): @pytest.fixture def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]: port_counter = 1235 + ip_counter = 1 def _create_connection( source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None ) -> Connection: nonlocal port_counter + nonlocal ip_counter + # assign unique ips + ip_counter += 1 if send_back_port is None: send_back_port = port_counter port_counter += 1 @@ -53,7 +57,7 @@ def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]: local_node_id=source_node_id, send_back_node_id=sink_node_id, send_back_multiaddr=Multiaddr( - address=f"/ip4/169.254.0.1/tcp/{send_back_port}" + address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}" ), connection_profile=ConnectionProfile( throughput=1000, latency=1000, jitter=1000 diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index a8b33e8e..c52b0b33 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -1,6 +1,7 @@ from typing import Callable import pytest +from loguru import logger from exo.master.placement import ( get_instance_placements_after_create, @@ -356,10 +357,18 @@ def test_tensor_rdma_backend_connectivity_matrix( conn_b_c = create_connection(node_id_b, node_id_c) conn_c_a = create_connection(node_id_c, node_id_a) + conn_b_a = create_connection(node_id_b, node_id_a) + conn_c_b = create_connection(node_id_c, node_id_b) + conn_a_c = create_connection(node_id_a, node_id_c) + assert conn_a_b.send_back_multiaddr is not None assert conn_b_c.send_back_multiaddr is not None assert conn_c_a.send_back_multiaddr is not None + assert conn_b_a.send_back_multiaddr is not None + assert conn_c_b.send_back_multiaddr is not None + assert conn_a_c.send_back_multiaddr is not None + node_a.node_profile = NodePerformanceProfile( model_id="test", chip_id="test", @@ -368,7 +377,12 @@ def test_tensor_rdma_backend_connectivity_matrix( network_interfaces=[ NetworkInterfaceInfo( name="en3", - ip_address=conn_a_b.send_back_multiaddr.ip_address, + ip_address=conn_c_a.send_back_multiaddr.ip_address, + type="rdma", + ), + NetworkInterfaceInfo( + name="en4", + ip_address=conn_b_a.send_back_multiaddr.ip_address, type="rdma", ), ethernet_interface, @@ -381,9 +395,14 @@ def test_tensor_rdma_backend_connectivity_matrix( friendly_name="test", memory=node_b.node_profile.memory, network_interfaces=[ + NetworkInterfaceInfo( + name="en3", + ip_address=conn_c_b.send_back_multiaddr.ip_address, + type="rdma", + ), NetworkInterfaceInfo( name="en4", - 
ip_address=conn_b_c.send_back_multiaddr.ip_address, + ip_address=conn_a_b.send_back_multiaddr.ip_address, type="rdma", ), ethernet_interface, @@ -397,8 +416,13 @@ def test_tensor_rdma_backend_connectivity_matrix( memory=node_c.node_profile.memory, network_interfaces=[ NetworkInterfaceInfo( - name="en5", - ip_address=conn_c_a.send_back_multiaddr.ip_address, + name="en3", + ip_address=conn_a_c.send_back_multiaddr.ip_address, + type="rdma", + ), + NetworkInterfaceInfo( + name="en4", + ip_address=conn_b_c.send_back_multiaddr.ip_address, type="rdma", ), ethernet_interface, @@ -412,6 +436,9 @@ def test_tensor_rdma_backend_connectivity_matrix( topology.add_connection(conn_a_b) topology.add_connection(conn_b_c) topology.add_connection(conn_c_a) + topology.add_connection(conn_b_a) + topology.add_connection(conn_c_b) + topology.add_connection(conn_a_c) create_instance_command = CreateInstance( command_id=CommandId(), @@ -444,9 +471,11 @@ def test_tensor_rdma_backend_connectivity_matrix( idx_b = node_to_idx[node_id_b] idx_c = node_to_idx[node_id_c] - assert matrix[idx_a][idx_b] == "rdma_en3" - assert matrix[idx_b][idx_c] == "rdma_en4" - assert matrix[idx_c][idx_a] == "rdma_en5" + logger.info(matrix) + + assert matrix[idx_a][idx_b] == "rdma_en4" + assert matrix[idx_b][idx_c] == "rdma_en3" + assert matrix[idx_c][idx_a] == "rdma_en3" assert ":" in instance.mlx_ibv_coordinator assert not instance.mlx_ibv_coordinator.startswith("169.254") diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index 16cc6adb..6ea031a7 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -252,9 +252,5 @@ def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> Sta if not topology.contains_connection(event.edge): return state topology.remove_connection(event.edge) - if not topology.contains_connection(event.edge) and topology.contains_connection( - event.edge.reverse() - ): - topology.remove_connection(event.edge.reverse()) # TODO: Clean up removing the reverse connection return state.model_copy(update={"topology": topology}) diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 12051b3b..6368a72d 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -14,32 +14,32 @@ class ModelCard(CamelCaseModel): MODEL_CARDS: dict[str, ModelCard] = { # deepseek v3 - "deepseek-v3-0324:4bit": ModelCard( - short_id="deepseek-v3-0324:4bit", - model_id="mlx-community/DeepSeek-V3-0324-4bit", - name="DeepSeek V3 0324 (4-bit)", - description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""", - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"), - pretty_name="DeepSeek V3 0324 (4-bit)", - storage_size=Memory.from_kb(409706307), - n_layers=61, - ), - ), - "deepseek-v3-0324": ModelCard( - short_id="deepseek-v3-0324", - model_id="mlx-community/DeepSeek-v3-0324-8bit", - name="DeepSeek V3 0324 (8-bit)", - description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""", - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"), - pretty_name="DeepSeek V3 0324 (8-bit)", - storage_size=Memory.from_kb(754706307), - n_layers=61, - ), - ), + # "deepseek-v3-0324:4bit": ModelCard( + # short_id="deepseek-v3-0324:4bit", + # model_id="mlx-community/DeepSeek-V3-0324-4bit", + # name="DeepSeek V3 0324 (4-bit)", + # description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""", 
+ # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"), + # pretty_name="DeepSeek V3 0324 (4-bit)", + # storage_size=Memory.from_kb(409706307), + # n_layers=61, + # ), + # ), + # "deepseek-v3-0324": ModelCard( + # short_id="deepseek-v3-0324", + # model_id="mlx-community/DeepSeek-v3-0324-8bit", + # name="DeepSeek V3 0324 (8-bit)", + # description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""", + # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"), + # pretty_name="DeepSeek V3 0324 (8-bit)", + # storage_size=Memory.from_kb(754706307), + # n_layers=61, + # ), + # ), "deepseek-v3.1": ModelCard( short_id="deepseek-v3.1", model_id="mlx-community/DeepSeek-V3.1-8bit", @@ -67,32 +67,32 @@ MODEL_CARDS: dict[str, ModelCard] = { ), ), # deepseek r1 - "deepseek-r1-0528:4bit": ModelCard( - short_id="deepseek-r1-0528:4bit", - model_id="mlx-community/DeepSeek-R1-0528-4bit", - name="DeepSeek-R1-0528 (4-bit)", - description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""", - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"), - pretty_name="DeepSeek R1 671B (4-bit)", - storage_size=Memory.from_kb(409706307), - n_layers=61, - ), - ), - "deepseek-r1-0528": ModelCard( - short_id="deepseek-r1-0528", - model_id="mlx-community/DeepSeek-R1-0528-8bit", - name="DeepSeek-R1-0528 (8-bit)", - description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""", - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"), - pretty_name="DeepSeek R1 671B (8-bit)", - storage_size=Memory.from_bytes(754998771712), - n_layers=61, - ), - ), + # "deepseek-r1-0528:4bit": ModelCard( + # short_id="deepseek-r1-0528:4bit", + # model_id="mlx-community/DeepSeek-R1-0528-4bit", + # name="DeepSeek-R1-0528 (4-bit)", + # description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""", + # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"), + # pretty_name="DeepSeek R1 671B (4-bit)", + # storage_size=Memory.from_kb(409706307), + # n_layers=61, + # ), + # ), + # "deepseek-r1-0528": ModelCard( + # short_id="deepseek-r1-0528", + # model_id="mlx-community/DeepSeek-R1-0528-8bit", + # name="DeepSeek-R1-0528 (8-bit)", + # description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""", + # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"), + # pretty_name="DeepSeek R1 671B (8-bit)", + # storage_size=Memory.from_bytes(754998771712), + # n_layers=61, + # ), + # ), # kimi k2 "kimi-k2-instruct-4bit": ModelCard( short_id="kimi-k2-instruct-4bit", @@ -228,19 +228,19 @@ MODEL_CARDS: dict[str, ModelCard] = { n_layers=32, ), ), - "phi-3-mini:128k": ModelCard( - short_id="phi-3-mini:128k", - model_id="mlx-community/Phi-3-mini-128k-instruct-4bit", - name="Phi 3 Mini 128k", - description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""", - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"), - pretty_name="Phi 3 Mini 128k", - storage_size=Memory.from_kb(2099262), - n_layers=32, - ), - ), + # "phi-3-mini:128k": ModelCard( + # short_id="phi-3-mini:128k", + # model_id="mlx-community/Phi-3-mini-128k-instruct-4bit", + # name="Phi 3 Mini 128k", + # description="""Phi 3 Mini is a large 
language model trained on the Phi 3 Mini dataset.""", + # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"), + # pretty_name="Phi 3 Mini 128k", + # storage_size=Memory.from_kb(2099262), + # n_layers=32, + # ), + # ), # qwen3 "qwen3-0.6b": ModelCard( short_id="qwen3-0.6b", @@ -268,19 +268,19 @@ MODEL_CARDS: dict[str, ModelCard] = { n_layers=48, ), ), - "qwen3-235b-a22b": ModelCard( - short_id="qwen3-235b-a22b", - model_id="mlx-community/Qwen3-235B-A22B-4bit", - name="Qwen3 235B, Active 22B (4-bit)", - description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""", - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"), - pretty_name="Qwen3 235B, Active 22B (4-bit)", - storage_size=Memory.from_kb(123207680), - n_layers=94, - ), - ), + # "qwen3-235b-a22b": ModelCard( + # short_id="qwen3-235b-a22b", + # model_id="mlx-community/Qwen3-235B-A22B-4bit", + # name="Qwen3 235B, Active 22B (4-bit)", + # description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""", + # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"), + # pretty_name="Qwen3 235B, Active 22B (4-bit)", + # storage_size=Memory.from_kb(123207680), + # n_layers=94, + # ), + # ), "qwen3-235b-a22b-8bit": ModelCard( short_id="qwen3-235b-a22b-8bit", model_id="mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit", @@ -308,31 +308,31 @@ MODEL_CARDS: dict[str, ModelCard] = { n_layers=40, ), ), - "granite-3.3-8b": ModelCard( - short_id="granite-3.3-8b", - model_id="mlx-community/granite-3.3-8b-instruct-fp16", - name="Granite 3.3 8B", - description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""", - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"), - pretty_name="Granite 3.3 8B", - storage_size=Memory.from_kb(15958720), - n_layers=40, - ), - ), + # "granite-3.3-8b": ModelCard( + # short_id="granite-3.3-8b", + # model_id="mlx-community/granite-3.3-8b-instruct-fp16", + # name="Granite 3.3 8B", + # description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""", + # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"), + # pretty_name="Granite 3.3 8B", + # storage_size=Memory.from_kb(15958720), + # n_layers=40, + # ), + # ), # smol-lm - "smol-lm-135m": ModelCard( - short_id="smol-lm-135m", - model_id="mlx-community/SmolLM-135M-4bit", - name="Smol LM 135M", - description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """, - tags=[], - metadata=ModelMetadata( - model_id=ModelId("mlx-community/SmolLM-135M-4bit"), - pretty_name="Smol LM 135M", - storage_size=Memory.from_kb(73940), - n_layers=30, - ), - ), + # "smol-lm-135m": ModelCard( + # short_id="smol-lm-135m", + # model_id="mlx-community/SmolLM-135M-4bit", + # name="Smol LM 135M", + # description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. 
""", + # tags=[], + # metadata=ModelMetadata( + # model_id=ModelId("mlx-community/SmolLM-135M-4bit"), + # pretty_name="Smol LM 135M", + # storage_size=Memory.from_kb(73940), + # n_layers=30, + # ), + # ), } diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index c88e1f59..7413161f 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -70,6 +70,9 @@ class Topology: if connection.send_back_node_id not in self._node_id_to_rx_id_map: self.add_node(NodeInfo(node_id=connection.send_back_node_id)) + if connection in self._edge_id_to_rx_id_map: + return + src_id = self._node_id_to_rx_id_map[connection.local_node_id] sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id] diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py index 131fc7e2..3ec61289 100644 --- a/src/exo/shared/types/api.py +++ b/src/exo/shared/types/api.py @@ -24,13 +24,20 @@ class ModelListModel(BaseModel): class ModelList(BaseModel): - object: str = "list" + object: Literal["list"] = "list" data: list[ModelListModel] +class ChatCompletionMessageText(BaseModel): + type: Literal["text"] = "text" + text: str + + class ChatCompletionMessage(BaseModel): role: Literal["system", "user", "assistant", "developer", "tool", "function"] - content: str | None = None + content: ( + str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None + ) = None thinking: str | None = None # Added for GPT-OSS harmony format support name: str | None = None tool_calls: list[dict[str, Any]] | None = None @@ -55,20 +62,6 @@ class Logprobs(BaseModel): content: list[LogprobsContentItem] | None = None -class StreamingChoiceResponse(BaseModel): - index: int - delta: ChatCompletionMessage - logprobs: Logprobs | None = None - finish_reason: FinishReason | None = None - - -class ChatCompletionChoice(BaseModel): - index: int - message: ChatCompletionMessage - logprobs: Logprobs | None = None - finish_reason: FinishReason | None = None - - class PromptTokensDetails(BaseModel): cached_tokens: int = 0 audio_tokens: int = 0 @@ -89,6 +82,21 @@ class Usage(BaseModel): completion_tokens_details: CompletionTokensDetails | None = None +class StreamingChoiceResponse(BaseModel): + index: int + delta: ChatCompletionMessage + logprobs: Logprobs | None = None + finish_reason: FinishReason | None = None + usage: Usage | None = None + + +class ChatCompletionChoice(BaseModel): + index: int + message: ChatCompletionMessage + logprobs: Logprobs | None = None + finish_reason: FinishReason | None = None + + class ChatCompletionResponse(BaseModel): id: str object: Literal["chat.completion"] = "chat.completion" @@ -125,8 +133,8 @@ class CreateInstanceTaskParams(BaseModel): # TODO: in future the user could specify a specific Instance, not just a model_id model_id: str sharding: Sharding = Sharding.Pipeline - # TODO: fix instance_meta: InstanceMeta = InstanceMeta.MlxRing + min_nodes: int = 1 class DeleteInstanceTaskParams(BaseModel): diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 9ea2aa3f..1deca8ff 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -29,6 +29,7 @@ class CreateInstance(BaseCommand): model_meta: ModelMetadata sharding: Sharding instance_meta: InstanceMeta + min_nodes: int class DeleteInstance(BaseCommand): diff --git a/src/exo/shared/types/topology.py b/src/exo/shared/types/topology.py index 1695a98b..33d7c752 100644 --- a/src/exo/shared/types/topology.py +++ b/src/exo/shared/types/topology.py @@ -12,20 +12,17 @@ class 
NodeInfo(CamelCaseModel): class Connection(CamelCaseModel): local_node_id: NodeId send_back_node_id: NodeId - send_back_multiaddr: Multiaddr | None + send_back_multiaddr: Multiaddr connection_profile: ConnectionProfile | None = None def __hash__(self) -> int: - if self.send_back_multiaddr: - return hash( - ( - self.local_node_id, - self.send_back_node_id, - self.send_back_multiaddr.address, - ) + return hash( + ( + self.local_node_id, + self.send_back_node_id, + self.send_back_multiaddr.address, ) - else: - return hash((self.local_node_id, self.send_back_node_id)) + ) def __eq__(self, other: object) -> bool: if not isinstance(other, Connection): @@ -37,13 +34,4 @@ class Connection(CamelCaseModel): ) def is_thunderbolt(self) -> bool: - return self.send_back_multiaddr is not None and str( - self.send_back_multiaddr.ipv4_address - ).startswith("169.254") - - def reverse(self) -> "Connection": - return Connection( - local_node_id=self.send_back_node_id, - send_back_node_id=self.local_node_id, - send_back_multiaddr=None, - ) + return str(self.send_back_multiaddr.ipv4_address).startswith("200.0") diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index 9230001f..b68e60a4 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -41,6 +41,7 @@ class BoundInstance(CamelCaseModel): instance: Instance bound_runner_id: RunnerId + @property def bound_shard(self) -> ShardMetadata: shard = self.instance.shard(self.bound_runner_id) assert shard is not None diff --git a/src/exo/shared/types/worker/runners.py b/src/exo/shared/types/worker/runners.py index da8544a3..5cceb83b 100644 --- a/src/exo/shared/types/worker/runners.py +++ b/src/exo/shared/types/worker/runners.py @@ -48,6 +48,7 @@ class RunnerRunning(BaseRunnerStatus): class RunnerShutdown(BaseRunnerStatus): pass + class RunnerFailed(BaseRunnerStatus): error_message: str | None = None diff --git a/src/exo/utils/banner.py b/src/exo/utils/banner.py new file mode 100644 index 00000000..cae6eac3 --- /dev/null +++ b/src/exo/utils/banner.py @@ -0,0 +1,34 @@ +def print_startup_banner(port: int) -> None: + """Print a prominent startup banner with API endpoint information.""" + banner = """ +╔═══════════════════════════════════════════════════════════════════════╗ +║ ║ +║ ███████╗██╗ ██╗ ██████╗ ║ +║ ██╔════╝╚██╗██╔╝██╔═══██╗ ║ +║ █████╗ ╚███╔╝ ██║ ██║ ║ +║ ██╔══╝ ██╔██╗ ██║ ██║ ║ +║ ███████╗██╔╝ ██╗╚██████╔╝ ║ +║ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ║ +║ ║ +║ Distributed AI Inference Cluster ║ +║ ║ +╚═══════════════════════════════════════════════════════════════════════╝ +""" + + dashboard_url = f"http://localhost:{port}" + + api_info = f""" +╔═══════════════════════════════════════════════════════════════════════╗ +║ ║ +║ 🌐 Dashboard & API Ready ║ +║ ║ +║ {dashboard_url}{" " * (69 - len(dashboard_url))}║ +║ ║ +║ Click the URL above to open the dashboard in your browser ║ +║ ║ +╚═══════════════════════════════════════════════════════════════════════╝ +""" + + print(banner) + print(api_info) + print() diff --git a/src/exo/utils/channels.py b/src/exo/utils/channels.py index 70971cf3..c335fb02 100644 --- a/src/exo/utils/channels.py +++ b/src/exo/utils/channels.py @@ -139,7 +139,9 @@ class MpSender[T]: # == unique to Mp channels == def join(self) -> None: """Ensure any queued messages are resolved before continuing""" - assert self._state.closed.is_set(), "Mp channels must be closed before being joined" + assert self._state.closed.is_set(), ( + "Mp channels must be closed before 
being joined" + ) self._state.buffer.join_thread() # == context manager support == @@ -209,7 +211,9 @@ class MpReceiver[T]: # == unique to Mp channels == def join(self) -> None: """Block until all enqueued messages are drained off our side of the buffer""" - assert self._state.closed.is_set(), "Mp channels must be closed before being joined" + assert self._state.closed.is_set(), ( + "Mp channels must be closed before being joined" + ) self._state.buffer.join_thread() # == iterator support == diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index dfdda537..e44b1975 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -100,12 +100,12 @@ def _model_needs_download( for runner in runners.values(): if ( isinstance(runner.status, RunnerWaitingForModel) - and runner.bound_instance.bound_shard() not in download_status + and runner.bound_instance.bound_shard not in download_status ): # We don't invalidate download_status randomly in case a file gets deleted on disk return DownloadModel( instance_id=runner.bound_instance.instance.instance_id, - shard_metadata=runner.bound_instance.bound_shard(), + shard_metadata=runner.bound_instance.bound_shard, ) @@ -160,7 +160,7 @@ def _ready_to_warmup( ) for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard ) - and runner.bound_instance.bound_shard().device_rank != 0 + and runner.bound_instance.bound_shard.device_rank != 0 ) or ( all( @@ -170,7 +170,7 @@ def _ready_to_warmup( for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard if global_runner_id != runner.bound_instance.bound_runner_id ) - and runner.bound_instance.bound_shard().device_rank == 0 + and runner.bound_instance.bound_shard.device_rank == 0 ) ): return StartWarmup(instance_id=runner.bound_instance.instance.instance_id) diff --git a/src/exo/worker/runner/generate.py b/src/exo/worker/runner/generate.py index 09d51b6c..134ac956 100644 --- a/src/exo/worker/runner/generate.py +++ b/src/exo/worker/runner/generate.py @@ -6,6 +6,9 @@ from mlx_lm.models.cache import KVCache from mlx_lm.tokenizer_utils import TokenizerWrapper from exo.engines.mlx import Model + +# from exo.engines.mlx.cache import KVPrefixCache +from exo.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS from exo.engines.mlx.utils_mlx import ( apply_chat_template, make_kv_cache, @@ -70,6 +73,8 @@ def warmup_inference( sampler=sampler, prompt_cache=cache, prefill_step_size=65536, + kv_group_size=KV_GROUP_SIZE, + kv_bits=KV_BITS, ): logger.info("Generated warmup token: " + str(_r.text)) tokens_generated += 1 @@ -94,19 +99,19 @@ def mlx_generate( chat_task_data=task, ) - cache = make_kv_cache( - model=model, - ) + caches = make_kv_cache(model=model) - max_tokens = task.max_tokens or 1000 + max_tokens = task.max_tokens or MAX_TOKENS for out in stream_generate( model=model, tokenizer=tokenizer, prompt=prompt, max_tokens=max_tokens, sampler=sampler, - prompt_cache=cache, + prompt_cache=caches, prefill_step_size=65536, + kv_group_size=KV_GROUP_SIZE, + kv_bits=KV_BITS, ): logger.info(out.text) if out.finish_reason is not None and out.finish_reason not in get_args( diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index f2b23e35..87eb742d 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -22,7 +22,7 @@ from exo.shared.types.tasks import ( ) from exo.shared.types.worker.commands_runner import ( GenerationResponse, - TokenizedResponse, + # TokenizedResponse, ) from 
exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.runners import ( @@ -31,12 +31,12 @@ from exo.shared.types.worker.runners import ( RunnerLoading, RunnerReady, RunnerRunning, + RunnerShutdown, RunnerStatus, RunnerWaitingForModel, RunnerWarmingUp, - RunnerShutdown ) -from exo.utils.channels import MpReceiver, MpSender, ClosedResourceError +from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender from exo.worker.runner.bootstrap import logger from exo.worker.runner.generate import mlx_generate, warmup_inference @@ -49,7 +49,7 @@ def main( instance, runner_id, shard_metadata = ( bound_instance.instance, bound_instance.bound_runner_id, - bound_instance.bound_shard(), + bound_instance.bound_shard, ) try: logger.info("hello from the runner") @@ -115,6 +115,7 @@ def main( model=model, tokenizer=tokenizer, sampler=sampler, + # kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching ) logger.info(f"warmed up by generating {toks} tokens") logger.info( @@ -185,9 +186,9 @@ def main( ), ) ) - case TokenizedResponse(): - # TODO: something here ig - logger.info("Finished tokenizing?") + # case TokenizedResponse(): + # TODO: something here ig + # logger.info("Finished tokenizing?") current_status = RunnerReady() logger.info("runner ready") @@ -212,9 +213,7 @@ def main( ) ) event_sender.send( - RunnerStatusUpdated( - runner_id=runner_id, runner_status=RunnerShutdown() - ) + RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown()) ) except ClosedResourceError: logger.warning("runner communication closed unexpectedly") diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index 768cefa8..cda356ae 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -67,7 +67,7 @@ class RunnerSupervisor: daemon=True, ) - shard_metadata = bound_instance.bound_shard() + shard_metadata = bound_instance.bound_shard self = cls( bound_instance=bound_instance, @@ -109,12 +109,13 @@ class RunnerSupervisor: if not self.runner_process.is_alive(): return - logger.critical("Runner process didn't respond to SIGKILL. System resources may have leaked") + logger.critical( + "Runner process didn't respond to SIGKILL. 
System resources may have leaked" + ) def shutdown(self): assert self._tg self._tg.cancel_scope.cancel() - async def start_task(self, task: Task): event = anyio.Event() @@ -126,7 +127,6 @@ class RunnerSupervisor: return await event.wait() - async def _forward_events(self): with self._ev_recv as events: try: @@ -140,7 +140,6 @@ class RunnerSupervisor: except (ClosedResourceError, BrokenResourceError) as e: await self._check_runner(e) - def __del__(self) -> None: if self.runner_process.is_alive(): logger.warning("RunnerSupervisor was not stopped cleanly.") @@ -152,7 +151,7 @@ class RunnerSupervisor: await to_thread.run_sync(self.runner_process.join, 1) rc = self.runner_process.exitcode if rc == 0: - # + # return if isinstance(rc, int) and rc < 0: diff --git a/src/exo/worker/tests/test_plan/test_worker_plan.py b/src/exo/worker/tests/test_plan/test_worker_plan.py index 02f9612d..c555edd4 100644 --- a/src/exo/worker/tests/test_plan/test_worker_plan.py +++ b/src/exo/worker/tests/test_plan/test_worker_plan.py @@ -1,4 +1,5 @@ import pytest +from exo.worker.common import AssignedRunner from exo.shared.types.api import ChatCompletionMessage from exo.shared.types.state import State @@ -27,7 +28,6 @@ from exo.shared.types.worker.runners import ( RunningRunnerStatus, ) from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.common import AssignedRunner from exo.worker.main import Worker from exo.worker.plan import plan from exo.worker.tests.constants import ( diff --git a/tmp/run_llm.sh b/tmp/run_llm.sh index 07599c2d..b9dbb61b 100755 --- a/tmp/run_llm.sh +++ b/tmp/run_llm.sh @@ -13,9 +13,9 @@ QUERY="$*" curl -sN -X POST "http://$HOST:8000/v1/chat/completions" \ -H "Content-Type: application/json" \ -d "{ - \"model\": \"mlx-community/Llama-3.3-70B-Instruct-8bit\", + \"model\": \"mlx-community/Kimi-K2-Thinking\", \"stream\": true, - \"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\" }] + \"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\"}] }" | grep --line-buffered '^data:' | grep --line-buffered -v 'data: \[DONE\]' |
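
Note on the prefix-match logic introduced in src/exo/engines/mlx/cache.py: _get_prefix_length compares the two token arrays element-wise and takes a cumulative product, which stays 1 until the first mismatch, so summing the mask gives the length of the shared prefix. A minimal standalone sketch of that trick (illustrative only, with made-up token values; not part of the patch):

import mlx.core as mx

# Element-wise equality, then cumprod: the mask stays 1 until the first
# mismatch, so its sum equals the shared-prefix length.
prompt = mx.array([1, 2, 3, 9, 5])
cached_prompt = mx.array([1, 2, 3, 4, 5])

n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal)          # [1, 1, 1, 0, 0]
print(int(mx.sum(prefix_mask).item()))   # -> 3 tokens of shared prefix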