Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)

Commit: Demo
Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>

34  .github/configs/bench_simple.yaml  (vendored)
@@ -26,7 +26,7 @@ model_ids:
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Pipeline"
sharding: "Tensor"

# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"
@@ -48,6 +48,11 @@ stages:
#   generation_length: 64
#   time_between_requests: 2.0
#   iterations: 5
# - name: "pp64_g64"
#   prompt_length: 64
#   generation_length: 64
#   time_between_requests: 2.0
#   iterations: 5
# - name: "pp64_g512"
#   prompt_length: 64
#   generation_length: 512
@@ -58,6 +63,11 @@ stages:
#   generation_length: 64
#   time_between_requests: 2.0
#   iterations: 5
- name: "pp256_g64"
  prompt_length: 256
  generation_length: 64
  time_between_requests: 2.0
  iterations: 5
# - name: "pp256_g512"
#   prompt_length: 256
#   generation_length: 512
@@ -83,26 +93,26 @@ stages:
#   generation_length: 512
#   time_between_requests: 2.0
#   iterations: 10
- name: "pp4096_g64"
  prompt_length: 4096
  generation_length: 64
  time_between_requests: 2.0
  iterations: 4
# - name: "pp4096_g64"
#   prompt_length: 4096
#   generation_length: 64
#   time_between_requests: 2.0
#   iterations: 4
# - name: "pp4096_g512"
#   prompt_length: 4096
#   generation_length: 512
#   time_between_requests: 2.0
#   iterations: 10
- name: "pp8192_g64"
  prompt_length: 8192
  generation_length: 64
  time_between_requests: 2.0
  iterations: 4
# - name: "pp8192_g64"
#   prompt_length: 8192
#   generation_length: 64
#   time_between_requests: 2.0
#   iterations: 5
# - name: "pp8192_g512"
#   prompt_length: 8192
#   generation_length: 512
#   time_between_requests: 2.0
#   iterations: 10
#   iterations: 5
# - name: "pp16384_g64"
#   prompt_length: 16384
#   generation_length: 64
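As a rough illustration of how a benchmark harness might consume this config, here is a small loader sketch; the loading code and any structure beyond the keys visible in the hunks above (sharding, instance_meta, stages and their fields) are assumptions, not exo's actual bench runner.

import yaml  # PyYAML, assumed available

with open(".github/configs/bench_simple.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["sharding"], cfg["instance_meta"])  # e.g. "Tensor", "MlxIbv" after this commit
for stage in cfg.get("stages", []):
    # each active stage is one prompt-length / generation-length sweep
    print(stage["name"], stage["prompt_length"], stage["generation_length"],
          stage["time_between_requests"], stage["iterations"])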
@@ -36,9 +36,7 @@ def save_prompt_cache(
    state.
    """

def load_prompt_cache(
    file_name, return_metadata=...
): # -> tuple[list[Any], Any] | list[Any]:
def load_prompt_cache(file_name: str, return_metadata=...) -> array:
    """
    Load a prompt cache from a file.
@@ -31,10 +31,10 @@
|
||||
max-width: 1200px;
|
||||
margin-bottom: 15px;
|
||||
margin-top: 20px;
|
||||
text-align: left;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
display: grid;
|
||||
grid-template-columns: 1fr auto 1fr;
|
||||
align-items: flex-end;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.dashboard-header h1 {
|
||||
@@ -67,6 +67,18 @@
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.header-center {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.header-right {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.header-instances-button {
|
||||
background-color: transparent;
|
||||
border: 1px solid var(--exo-medium-gray);
|
||||
@@ -972,11 +984,11 @@
|
||||
<label class="launch-label">Sharding:</label>
|
||||
<div class="strategy-options">
|
||||
<div class="strategy-option">
|
||||
<input type="radio" id="shardingPipeline" name="sharding" value="Pipeline" checked>
|
||||
<input type="radio" id="shardingPipeline" name="sharding" value="Pipeline">
|
||||
<label for="shardingPipeline">Pipeline</label>
|
||||
</div>
|
||||
<div class="strategy-option">
|
||||
<input type="radio" id="shardingTensor" name="sharding" value="Tensor">
|
||||
<input type="radio" id="shardingTensor" name="sharding" value="Tensor" checked>
|
||||
<label for="shardingTensor">Tensor</label>
|
||||
</div>
|
||||
</div>
|
||||
@@ -986,16 +998,26 @@
|
||||
<label class="launch-label">Instance Type:</label>
|
||||
<div class="strategy-options">
|
||||
<div class="strategy-option">
|
||||
<input type="radio" id="instanceMlxRing" name="instance_meta" value="MlxRing" checked>
|
||||
<input type="radio" id="instanceMlxRing" name="instance_meta" value="MlxRing">
|
||||
<label for="instanceMlxRing">MLX Ring</label>
|
||||
</div>
|
||||
<div class="strategy-option">
|
||||
<input type="radio" id="instanceMlxIbv" name="instance_meta" value="MlxIbv">
|
||||
<input type="radio" id="instanceMlxIbv" name="instance_meta" value="MlxIbv" checked>
|
||||
<label for="instanceMlxIbv">MLX IBV</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="strategy-selector">
|
||||
<label class="launch-label">Minimum Nodes:</label>
|
||||
<div class="strategy-options" id="minNodesOptions">
|
||||
<div class="strategy-option">
|
||||
<input type="radio" id="minNodes1" name="min_nodes" value="1" checked>
|
||||
<label for="minNodes1">1</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button id="launchInstanceButton" class="launch-button" disabled>Launch Instance</button>
|
||||
<div id="launchStatus" class="launch-status"></div>
|
||||
</div>
|
||||
@@ -1004,10 +1026,15 @@
|
||||
|
||||
<div class="dashboard-header">
|
||||
<div class="header-left">
|
||||
<!-- Left section: empty or can be used for future content -->
|
||||
</div>
|
||||
<div class="header-center">
|
||||
<h1><img src="exo-logo.png" alt="EXO logo" height="48" /></h1>
|
||||
<p class="last-updated" id="lastUpdated">Fetching data...</p>
|
||||
</div>
|
||||
<button class="header-instances-button" id="instancesMenuButton">Instances</button>
|
||||
<div class="header-right">
|
||||
<button class="header-instances-button" id="instancesMenuButton">Instances</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Replaced node-grid with SVG container for topology graph -->
|
||||
@@ -1068,12 +1095,14 @@
|
||||
const modelSelect = document.getElementById('modelSelect');
|
||||
const launchInstanceButton = document.getElementById('launchInstanceButton');
|
||||
const launchStatus = document.getElementById('launchStatus');
|
||||
const minNodesOptions = document.getElementById('minNodesOptions');
|
||||
|
||||
const USE_MOCK_DATA = false; // <<< FLAG TO TOGGLE MOCK DATA
|
||||
let currentlySelectedNodeId = null; // To store the ID of the currently selected node
|
||||
let nodeIdToFriendlyName = {}; // Map nodeId -> friendly name for download sections
|
||||
let instanceIdToColor = {}; // Map instanceId -> color for visual coding
|
||||
let connectionToInstances = {}; // Map "nodeA|nodeB" -> [instanceIds] using that connection
|
||||
let currentNodeCount = 1; // Track the current number of nodes in topology
|
||||
|
||||
const API_ENDPOINT = window.location.origin + window.location.pathname.replace(/\/$/, "") + '/state';
|
||||
const REFRESH_INTERVAL = 1000; // 1 second
|
||||
@@ -1218,6 +1247,16 @@
|
||||
instancesMenuButton.classList.toggle('active', sidebarOpen);
|
||||
}
|
||||
|
||||
// Edge IP display flag (can be toggled from console)
|
||||
window.exoShowEdgeIPs = false;
|
||||
|
||||
// Helper function to toggle IP display (accessible from console)
|
||||
window.toggleEdgeIPs = function() {
|
||||
window.exoShowEdgeIPs = !window.exoShowEdgeIPs;
|
||||
console.log(`Edge IP display ${window.exoShowEdgeIPs ? 'enabled' : 'disabled'}`);
|
||||
return window.exoShowEdgeIPs;
|
||||
};
|
||||
|
||||
// Fetch available models and populate dropdown
|
||||
async function fetchAndPopulateModels() {
|
||||
try {
|
||||
@@ -1281,8 +1320,11 @@
|
||||
|
||||
const selectedSharding = document.querySelector('input[name="sharding"]:checked').value;
|
||||
const selectedInstanceMeta = document.querySelector('input[name="instance_meta"]:checked').value;
|
||||
const minNodesRadio = document.querySelector('input[name="min_nodes"]:checked');
|
||||
const minNodes = minNodesRadio ? parseInt(minNodesRadio.value, 10) : 1;
|
||||
console.log("selectedSharding", selectedSharding);
|
||||
console.log("selectedInstanceMeta", selectedInstanceMeta);
|
||||
console.log("minNodes", minNodes);
|
||||
|
||||
try {
|
||||
showLaunchStatus('Launching instance...', 'loading');
|
||||
@@ -1296,7 +1338,8 @@
|
||||
body: JSON.stringify({
|
||||
model_id: selectedModelId,
|
||||
sharding: selectedSharding,
|
||||
instance_meta: selectedInstanceMeta
|
||||
instance_meta: selectedInstanceMeta,
|
||||
min_nodes: minNodes
|
||||
})
|
||||
});
|
||||
|
||||
@@ -1858,6 +1901,39 @@
|
||||
const edgesData = (topologyData && Array.isArray(topologyData.edges)) ? topologyData.edges : [];
|
||||
const nodeIds = Object.keys(nodesData);
|
||||
|
||||
// Update min nodes radio buttons based on current topology
|
||||
currentNodeCount = Math.max(1, nodeIds.length);
|
||||
if (minNodesOptions) {
|
||||
// Get currently selected value before regenerating
|
||||
const currentlySelected = document.querySelector('input[name="min_nodes"]:checked');
|
||||
const hasOnlyDefaultOption = minNodesOptions.children.length === 1;
|
||||
// Default to maximum nodes on initial load, otherwise preserve user selection
|
||||
const selectedValue = (currentlySelected && !hasOnlyDefaultOption) ? parseInt(currentlySelected.value, 10) : currentNodeCount;
|
||||
|
||||
// Clear and regenerate radio buttons
|
||||
minNodesOptions.innerHTML = '';
|
||||
for (let i = 1; i <= currentNodeCount; i++) {
|
||||
const optionDiv = document.createElement('div');
|
||||
optionDiv.className = 'strategy-option';
|
||||
|
||||
const radio = document.createElement('input');
|
||||
radio.type = 'radio';
|
||||
radio.id = `minNodes${i}`;
|
||||
radio.name = 'min_nodes';
|
||||
radio.value = i.toString();
|
||||
// Check if this should be selected (preserve selection or default to maximum)
|
||||
radio.checked = (i === Math.min(selectedValue, currentNodeCount));
|
||||
|
||||
const label = document.createElement('label');
|
||||
label.htmlFor = `minNodes${i}`;
|
||||
label.textContent = i.toString();
|
||||
|
||||
optionDiv.appendChild(radio);
|
||||
optionDiv.appendChild(label);
|
||||
minNodesOptions.appendChild(optionDiv);
|
||||
}
|
||||
}
|
||||
|
||||
if (nodeIds.length === 0) {
|
||||
const textEl = document.createElementNS('http://www.w3.org/2000/svg', 'text');
|
||||
textEl.setAttribute('x', '50%');
|
||||
@@ -2002,7 +2078,7 @@
|
||||
arrowsGroup.appendChild(arrowSeg);
|
||||
|
||||
// Add label for A->B direction (show all connections)
|
||||
if (entry.aToBEdges && entry.aToBEdges.length > 0) {
|
||||
if (window.exoShowEdgeIPs && entry.aToBEdges && entry.aToBEdges.length > 0) {
|
||||
// Count occurrences of each IP/interface combination
|
||||
const connectionCounts = new Map();
|
||||
|
||||
@@ -2067,7 +2143,7 @@
|
||||
arrowsGroup.appendChild(arrowSeg);
|
||||
|
||||
// Add label for B->A direction (show all connections)
|
||||
if (entry.bToAEdges && entry.bToAEdges.length > 0) {
|
||||
if (window.exoShowEdgeIPs && entry.bToAEdges && entry.bToAEdges.length > 0) {
|
||||
// Count occurrences of each IP/interface combination
|
||||
const connectionCounts = new Map();
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@ from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast, override

from mlx_lm.models.cache import KVCache, RotatingKVCache
from mlx_lm.models.cache import (
    KVCache,
    _BaseCache,  # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.llama import Model as LlamaModel
@@ -91,7 +94,7 @@ class PipelineLastLayer(CustomMlxLayer):
            x, *args, **kwargs
        ).arguments.get("cache", None)

        assert cache is None or isinstance(cache, (KVCache, RotatingKVCache))
        assert cache is None or issubclass(type(cache), _BaseCache)  # type: ignore

        output: mx.array = self.original_layer(x, *args, **kwargs)

@@ -99,11 +102,7 @@ class PipelineLastLayer(CustomMlxLayer):
            output = mx.distributed.send(
                output, (self.r + 1) % self.s, group=self.group
            )
        if (
            cache is not None
            and hasattr(cache, "keys")
            and getattr(cache, "keys", None) is not None
        ):
        if cache is not None:
            # This change happened upstream - check out mlx github somewhere??
            cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]
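Why the cache is tied to the send: MLX evaluates lazily, and nothing on this rank consumes the value returned by mx.distributed.send, so without an explicit dependency the send could be skipped or reordered. Grafting a dependency onto the cache keys forces the send to be evaluated whenever the cache is. A minimal sketch of that pattern; the helper name is illustrative, while mx.distributed.send and mx.depends are the calls used in the hunk above.

import mlx.core as mx

def send_activation(output: mx.array, cache_keys: mx.array,
                    dst: int, group: mx.distributed.Group) -> mx.array:
    # ship the last pipeline stage's activation to the next rank
    sent = mx.distributed.send(output, dst, group=group)
    # make the cache depend on the send result, mirroring
    # cache.keys = mx.depends(cache.keys, output) in the diff above
    return mx.depends(cache_keys, sent)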
102  src/exo/engines/mlx/cache.py  (new file)
@@ -0,0 +1,102 @@
from copy import deepcopy
from typing import Callable

from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper

import mlx.core as mx
from exo.engines.mlx import Model
from exo.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.engines.mlx.utils_mlx import make_kv_cache


class KVPrefixCache:
    def __init__(self):
        # Only one prefix cache per runner.
        self.prompts: list[mx.array] = []  # mx array of tokens (ints)
        self.caches: list[list[_BaseCache]] = []

    def add_kv_cache(
        self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
    ):
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
        self.prompts.append(tokenized_prompt)
        self.caches.append(deepcopy(cache))

    def get_kv_cache(
        self,
        model: Model,
        tokenizer: TokenizerWrapper,
        sampler: Callable[[mx.array], mx.array],
        prompt: str,
    ) -> list[_BaseCache]:
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
        max_length = len(tokenized_prompt)

        best_snapshot_index, best_snapshot_length = None, 0

        for i, cached_prompt in enumerate(self.prompts):
            length = _get_prefix_length(tokenized_prompt, cached_prompt)

            if length == max_length:
                return self.caches[i]

            if length > best_snapshot_length:
                best_snapshot_index, best_snapshot_length = i, length

        if best_snapshot_index is not None:
            prompt_cache = deepcopy(self.caches[best_snapshot_index])
            trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
            tokenized_prompt = tokenized_prompt[best_snapshot_index:]
        else:
            prompt_cache = make_kv_cache(
                model,
                # max_kv_size=MAX_KV_SIZE,
                # keep=KEEP_KV_SIZE
            )

        prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)

        return prompt_cache

    def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
            tokenizer.bos_token
        )
        tokenized_prompt = tokenizer.encode(
            prompt, add_special_tokens=add_special_tokens
        )
        return mx.array(tokenized_prompt)


def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
    if n == 0:
        return 0

    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
    prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
    return int(mx.sum(prefix_mask).item())


def prefill(
    model: Model,
    tokenizer: TokenizerWrapper,
    sampler: Callable[[mx.array], mx.array],
    prompt: mx.array,
    cache: list[_BaseCache],
) -> None:
    for _ in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=0,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        pass
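A small, self-contained illustration of the prefix-matching primitive used by KVPrefixCache above: a new prompt is compared token-by-token against each cached prompt, the matching prefix is reused, and only the unmatched tail needs a fresh prefill. The token values below are made up; only mlx.core is required.

import mlx.core as mx

def prefix_length(prompt: mx.array, cached_prompt: mx.array, keep: int = 1600) -> int:
    # same idea as _get_prefix_length: count leading positions that match,
    # capped at `keep` tokens (KEEP_KV_SIZE in the real code)
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), keep)
    if n == 0:
        return 0
    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
    return int(mx.sum(mx.cumprod(equal)).item())

cached = mx.array([1, 2, 3, 4, 5])
new = mx.array([1, 2, 3, 9, 9, 9])
print(prefix_length(new, cached))  # -> 3: reuse the first 3 cached tokens,
                                   # trim the cached KV past that point, prefill the tail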
17
src/exo/engines/mlx/constants.py
Normal file
17
src/exo/engines/mlx/constants.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# TODO: Do we want so many constants?
|
||||
|
||||
KV_GROUP_SIZE = 32
|
||||
KV_BITS = None
|
||||
ATTENTION_KV_BITS = 4
|
||||
MAX_TOKENS = 8192
|
||||
MAX_KV_SIZE = 3200
|
||||
KEEP_KV_SIZE = 1600
|
||||
QUANTIZE_MODEL_MODE = "affine"
|
||||
CACHE_GROUP_SIZE = 64
|
||||
KV_CACHE_BITS = 8
|
||||
TEMPERATURE = 1.0
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
TRUST_REMOTE_CODE = True
|
||||
# TODO: Do we really want this?
|
||||
HIDE_THINKING = False
|
||||
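For intuition on why CACHE_GROUP_SIZE and KV_CACHE_BITS matter, here is a back-of-the-envelope memory estimate. The head count, head dimension, and the assumption of one fp16 scale and bias per quantization group are illustrative, not values taken from this repo.

# Rough per-token, per-layer KV memory: fp16 vs 8-bit grouped quantization.
n_kv_heads, head_dim = 8, 128
elems = 2 * n_kv_heads * head_dim            # keys + values
fp16_bytes = elems * 2                        # 2 bytes per element
groups = elems // 64                          # CACHE_GROUP_SIZE = 64
q8_bytes = elems * 1 + groups * 2 * 2         # 1 byte/element + fp16 scale and bias per group
print(fp16_bytes, q8_bytes)                   # 4096 vs 2176 bytes, roughly a 2x saving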
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import resource
|
||||
import time
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from mlx_lm.models.cache import KVCache, RotatingKVCache
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
@@ -22,6 +24,14 @@ from exo.engines.mlx.auto_parallel import (
|
||||
pipeline_auto_parallel,
|
||||
tensor_auto_parallel,
|
||||
)
|
||||
from exo.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
PATCH_SYSTEM_PROMPT,
|
||||
TEMPERATURE,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
@@ -44,7 +54,6 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
mlx_rank: None | int = None
|
||||
mlx_world_size: None | int = None
|
||||
|
||||
|
||||
def mx_barrier(group: mx.distributed.Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
@@ -87,7 +96,7 @@ def mlx_distributed_init(
|
||||
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
|
||||
- strict: if True, raise an error if the distributed backend is not available
|
||||
"""
|
||||
rank = bound_instance.bound_shard().device_rank
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
|
||||
# TODO: singleton instances
|
||||
@@ -136,33 +145,40 @@ def initialize_mlx(
|
||||
"""
|
||||
mx.random.seed(42)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard()))
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
|
||||
logger.info("Created a sampler")
|
||||
|
||||
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard().model_meta.model_id)
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
# TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
tokenizer = cast(
|
||||
TokenizerWrapper,
|
||||
load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": True},
|
||||
# TODO: HACK for Kimi K2 wrong eos token id
|
||||
eos_token_ids=[163586] if "kimi-k2" in bound_instance.bound_shard().model_meta.model_id.lower() else None,
|
||||
),
|
||||
)
|
||||
assert isinstance(tokenizer, TokenizerWrapper)
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, config = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
if isinstance(model.model, DeepseekV3Model):
|
||||
pass
|
||||
# model, config = quantize_model(
|
||||
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
|
||||
# )
|
||||
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard(), group=group)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard()))
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
logger.debug(model)
|
||||
|
||||
return cast(Model, model), tokenizer, sampler
|
||||
|
||||
@@ -174,20 +190,28 @@ def shard_and_load(
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
model, config = load_model(model_path, lazy=True, strict=False)
|
||||
logger.info(f"{config=}")
|
||||
logger.debug(model)
|
||||
if isinstance(model.model, DeepseekV3Model):
|
||||
pass
|
||||
# TODO: See if we should quantize the model.
|
||||
# def is_attention_layer(path: str) -> bool:
|
||||
# path = path.lower()
|
||||
|
||||
# return "self_attn" in path and "layernorm" not in path
|
||||
|
||||
|
||||
# def quant_predicate(path: str, module: nn.Module):
|
||||
# if not isinstance(module, nn.Linear):
|
||||
# return False
|
||||
|
||||
# return is_attention_layer(path)
|
||||
# model, config = quantize_model(
|
||||
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
|
||||
# )
|
||||
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
# TODO: we should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
tokenizer = cast(
|
||||
TokenizerWrapper,
|
||||
# TODO: HACK for Kimi K2 wrong eos token id
|
||||
load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": True},
|
||||
# TODO: HACK for Kimi K2 wrong eos token id
|
||||
eos_token_ids=[163586] if "kimi-k2" in shard_metadata.model_meta.model_id.lower() else None,
|
||||
),
|
||||
)
|
||||
tokenizer = get_tokenizer(model_path, shard_metadata)
|
||||
|
||||
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
|
||||
@@ -200,44 +224,63 @@ def shard_and_load(
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
logger.debug("SHARDED")
|
||||
logger.debug(model)
|
||||
|
||||
# Synchronize processes before generation to avoid timeout
|
||||
mx_barrier(group)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_tokenizer(model_path: str, shard_metadata: ShardMetadata):
|
||||
tokenizer = cast(
|
||||
TokenizerWrapper,
|
||||
load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
# TODO: HACK for Kimi K2 wrong eos token id
|
||||
eos_token_ids=[163586]
|
||||
if "kimi-k2" in shard_metadata.model_meta.model_id.lower()
|
||||
else None,
|
||||
),
|
||||
)
|
||||
assert isinstance(tokenizer, TokenizerWrapper)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
chat_task_data: ChatCompletionTaskParams,
|
||||
) -> str:
|
||||
# Now we can properly access the messages
|
||||
messages = chat_task_data.messages
|
||||
messages_dicts: list[dict[str, Any]] = [msg.model_dump() for msg in messages]
|
||||
|
||||
# Filter out None values, keeping relevant keys for the model
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
for message in messages_dicts:
|
||||
filtered_message: dict[str, Any] = {
|
||||
k: v
|
||||
for k, v in message.items() # pyright: ignore[reportAny]
|
||||
if v is not None
|
||||
}
|
||||
for i, message in enumerate(messages):
|
||||
if isinstance(message.content, ChatCompletionMessageText):
|
||||
message.content = message.content.text
|
||||
if isinstance(message.content, list):
|
||||
if len(message.content) != 1:
|
||||
logger.warning("Received malformed prompt")
|
||||
continue
|
||||
|
||||
# Verify we have required fields
|
||||
if "role" not in filtered_message:
|
||||
raise ValueError(f"Message missing 'role' field: {filtered_message}")
|
||||
if "content" not in filtered_message and "thinking" not in filtered_message:
|
||||
# If neither content nor thinking is present, skip this message
|
||||
message.content = message.content[0].text
|
||||
if message.content is None and message.thinking is None:
|
||||
continue
|
||||
|
||||
formatted_messages.append(filtered_message)
|
||||
|
||||
messages_dicts = formatted_messages
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
formatted_messages.append(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None}
|
||||
)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template( # type: ignore
|
||||
messages_dicts,
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
@@ -269,16 +312,23 @@ class NullKVCache(KVCache):
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model,
|
||||
max_kv_size: int | None = None,
|
||||
) -> list[KVCache | RotatingKVCache]:
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
if max_kv_size is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size) for _ in model.layers]
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
|
||||
|
||||
|
||||
def mlx_force_oom(size: int = 40000) -> None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import signal
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
@@ -8,14 +8,13 @@ import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from pydantic import PositiveInt
|
||||
|
||||
from exo.shared.logging import logger
|
||||
import exo.routing.topics as topics
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
from exo.shared.constants import EXO_LOG
|
||||
from exo.shared.election import Election, ElectionResult
|
||||
from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.logging import logger, logger_cleanup, logger_setup
|
||||
from exo.shared.types.commands import KillCommand
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
@@ -119,6 +118,7 @@ class Node:
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
if self._tg.cancel_scope.cancel_called:
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
@@ -208,7 +208,7 @@ class Node:
|
||||
|
||||
def main():
|
||||
args = Args.parse()
|
||||
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
|
||||
@@ -14,6 +14,7 @@ from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from loguru import logger
|
||||
|
||||
from exo.engines.mlx.constants import HIDE_THINKING
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
@@ -45,6 +46,7 @@ from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
@@ -171,9 +173,9 @@ class API:
|
||||
)
|
||||
|
||||
command = CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=model_meta,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
sharding=payload.sharding,
|
||||
)
|
||||
await self._send(command)
|
||||
@@ -194,7 +196,6 @@ class API:
|
||||
raise HTTPException(status_code=404, detail="Instance not found")
|
||||
|
||||
command = DeleteInstance(
|
||||
command_id=CommandId(),
|
||||
instance_id=instance_id,
|
||||
)
|
||||
await self._send(command)
|
||||
@@ -212,17 +213,26 @@ class API:
|
||||
self._chat_completion_queues[command_id] = asyncio.Queue()
|
||||
|
||||
finished = False
|
||||
is_thinking = False
|
||||
while not finished:
|
||||
# TODO: how long should this timeout be?
|
||||
chunk = await asyncio.wait_for(
|
||||
self._chat_completion_queues[command_id].get(), timeout=600
|
||||
)
|
||||
assert isinstance(chunk, TokenChunk)
|
||||
# TODO: Do we want this?
|
||||
if HIDE_THINKING:
|
||||
if chunk.text == "<think>":
|
||||
chunk.text = "\n"
|
||||
if chunk.text == "</think>":
|
||||
chunk.text = "\n"
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if not HIDE_THINKING or not is_thinking:
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -244,31 +254,6 @@ class API:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
# Preprocess messages for GPT-OSS harmony format if needed
|
||||
# TODO: This is slop surely we get rid
|
||||
if "gpt-oss" in payload.model.lower():
|
||||
import re
|
||||
|
||||
for message in payload.messages:
|
||||
if message.content and "<|channel|>" in message.content:
|
||||
# Parse harmony format tags
|
||||
thinking_pattern = r"<\|channel\|>(.*?)(?=<\|message\|>|$)"
|
||||
content_pattern = r"<\|message\|>(.*?)(?=<\|end\|>|$)"
|
||||
|
||||
thinking_match = re.search(
|
||||
thinking_pattern, message.content, re.DOTALL
|
||||
)
|
||||
content_match = re.search(
|
||||
content_pattern, message.content, re.DOTALL
|
||||
)
|
||||
|
||||
if content_match:
|
||||
# Extract the actual content
|
||||
message.content = content_match.group(1).strip()
|
||||
if thinking_match:
|
||||
# Store thinking in the thinking field
|
||||
message.thinking = thinking_match.group(1).strip()
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
for instance in self.state.instances.values()
|
||||
@@ -279,7 +264,6 @@ class API:
|
||||
)
|
||||
|
||||
command = ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
request_params=payload,
|
||||
)
|
||||
await self._send(command)
|
||||
@@ -325,9 +309,19 @@ class API:
|
||||
tg.start_soon(uvicorn_server.serve)
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._print_banner_when_ready, uvicorn_server)
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _print_banner_when_ready(self, uvicorn_server: uvicorn.Server):
|
||||
"""Wait for the uvicorn server to be ready, then print the startup banner."""
|
||||
# TODO: Is this the best condition to check for?
|
||||
# The point is this should log when exo is ready.
|
||||
while not uvicorn_server.started:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
print_startup_banner(self.port)
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -209,7 +209,7 @@ class Master:

        event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]

        # TODO: SQL
        # TODO: SQL <- What does this mean?
        self._event_log.append(event)
        await self._send_event(indexed)
@@ -3,6 +3,8 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence

from loguru import logger

from exo.master.placement_utils import (
    filter_cycles_by_memory,
    get_hosts_from_subgraph,
@@ -41,13 +43,13 @@ def get_instance_placements_after_create(
    tb_only: bool = False,
) -> dict[InstanceId, Instance]:
    all_nodes = list(topology.list_nodes())
    from loguru import logger

    logger.info("finding cycles:")
    cycles = topology.get_cycles()
    logger.info(f"{cycles=}")
    singleton_cycles = [[node] for node in all_nodes]
    candidate_cycles = cycles + singleton_cycles
    candidate_cycles = list(
        filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
    )
    cycles_with_sufficient_memory = filter_cycles_by_memory(
        candidate_cycles, command.model_meta.storage_size
    )
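A stripped-down sketch of the candidate filtering introduced above: every cycle in the topology plus every single node is a candidate group, and the new min_nodes field drops groups that are too small before memory filtering. Node names and list types are simplified stand-ins for the real topology objects.

def candidate_groups(cycles: list[list[str]], all_nodes: list[str],
                     min_nodes: int) -> list[list[str]]:
    # every node can also run alone, so singleton "cycles" are always candidates
    singleton_cycles = [[node] for node in all_nodes]
    # min_nodes (added by this commit) drops groups that are too small
    return [c for c in cycles + singleton_cycles if len(c) >= min_nodes]

print(candidate_groups([["a", "b"], ["a", "b", "c"]], ["a", "b", "c"], min_nodes=2))
# -> [['a', 'b'], ['a', 'b', 'c']]; singletons are filtered out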
@@ -210,27 +210,19 @@ def get_mlx_ibv_devices_matrix(
            if i == j:
                continue

            # just for debugging for now...
            for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph):
                interface_name = _find_interface_name_for_ip(connection_ip, node_i)
                logger.info(
                    f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
                )

            matrix[i][j] = "rdma_en3"  # TODO: hack, for now it's always en3
            continue

            for connection_ip in _find_connection_ip(node_i, node_j, cycle_digraph):
                # Set the first valid rmda i -> j connection - if there are multiple, we set essentially randomly - this is fine, the connection doesn't appear to have to be bidirectional
                if (
                    interface_name := _find_interface_name_for_ip(
                        connection_ip,
                        node_i,
                    )
                ) is not None:
            # Find the IP J uses to talk to I
            for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
                # This is a local IP on I, which is attached to an interface: find that interface
                if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
                    matrix[i][j] = interface_name
                    logger.info(
                        f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
                    )
                    break
            else:
                logger.warning(
                    f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
                )
                raise ValueError(
                    "Current ibv backend requires all-to-all rdma connections"
                )
@@ -246,8 +238,9 @@ def _find_connection_ip(
    """Find all IP addresses that connect node i to node j."""
    for connection in cycle_digraph.list_connections():
        if (
            connection.local_node_id == node_j.node_id
            and connection.send_back_node_id == node_i.node_id
            connection.local_node_id == node_i.node_id
            and connection.send_back_node_id == node_j.node_id
            # TODO: Check if we need this.
            and connection.send_back_multiaddr is not None
        ):
            yield connection.send_back_multiaddr.ip_address
@@ -260,13 +253,13 @@ def _find_interface_name_for_ip(
    if node_info.node_profile is None:
        return None

    logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
    for interface in node_info.node_profile.network_interfaces:
        logger.info(
            f"Checking interface {interface.name} for IP {interface.ip_address} == {ip_address}: {interface.ip_address == ip_address}"
        )
        if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
            continue
        logger.info(f" | {interface.name}: {interface.ip_address}")
        if interface.ip_address == ip_address:
            logger.info("Found")
            return f"rdma_{interface.name}"

    return None
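The subtle part of the new matrix code is the direction flip: to fill matrix[i][j] it scans connections where node j is the local side and node i is the send-back side, because (per the comments in the hunk) the send-back address is an IP that lives on node i, which can then be matched against node i's own interface table and turned into an rdma_<iface> name. A simplified sketch with plain dicts; the data shapes and example addresses are illustrative, not the repo's real types.

from typing import Iterator

# connections as (local_node, send_back_node, send_back_ip) triples
connections = [
    ("B", "A", "10.0.0.1"),  # B reaches A via 10.0.0.1, an address on A
]
# node -> {interface name: ip} for that node's own interfaces
interfaces = {"A": {"en3": "10.0.0.1", "en0": "192.168.1.5"}}

def send_back_ips(local: str, send_back: str) -> Iterator[str]:
    for l, s, ip in connections:
        if l == local and s == send_back:
            yield ip

def rdma_interface(i: str, j: str) -> str | None:
    # find the IP j uses to talk to i, then resolve it on i's interface table
    for ip in send_back_ips(j, i):
        for name, local_ip in interfaces.get(i, {}).items():
            if local_ip == ip:
                return f"rdma_{name}"
    return None

print(rdma_interface("A", "B"))  # -> "rdma_en3"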
@@ -41,11 +41,15 @@ def create_node():
|
||||
@pytest.fixture
|
||||
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
||||
port_counter = 1235
|
||||
ip_counter = 1
|
||||
|
||||
def _create_connection(
|
||||
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
|
||||
) -> Connection:
|
||||
nonlocal port_counter
|
||||
nonlocal ip_counter
|
||||
# assign unique ips
|
||||
ip_counter += 1
|
||||
if send_back_port is None:
|
||||
send_back_port = port_counter
|
||||
port_counter += 1
|
||||
@@ -53,7 +57,7 @@ def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
||||
local_node_id=source_node_id,
|
||||
send_back_node_id=sink_node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/169.254.0.1/tcp/{send_back_port}"
|
||||
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
|
||||
),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
get_instance_placements_after_create,
|
||||
@@ -356,10 +357,18 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
conn_b_c = create_connection(node_id_b, node_id_c)
|
||||
conn_c_a = create_connection(node_id_c, node_id_a)
|
||||
|
||||
conn_b_a = create_connection(node_id_b, node_id_a)
|
||||
conn_c_b = create_connection(node_id_c, node_id_b)
|
||||
conn_a_c = create_connection(node_id_a, node_id_c)
|
||||
|
||||
assert conn_a_b.send_back_multiaddr is not None
|
||||
assert conn_b_c.send_back_multiaddr is not None
|
||||
assert conn_c_a.send_back_multiaddr is not None
|
||||
|
||||
assert conn_b_a.send_back_multiaddr is not None
|
||||
assert conn_c_b.send_back_multiaddr is not None
|
||||
assert conn_a_c.send_back_multiaddr is not None
|
||||
|
||||
node_a.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
@@ -368,7 +377,12 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
type="rdma",
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
||||
type="rdma",
|
||||
),
|
||||
ethernet_interface,
|
||||
@@ -381,9 +395,14 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
friendly_name="test",
|
||||
memory=node_b.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
||||
type="rdma",
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
type="rdma",
|
||||
),
|
||||
ethernet_interface,
|
||||
@@ -397,8 +416,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
memory=node_c.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en5",
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
name="en3",
|
||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
||||
type="rdma",
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
type="rdma",
|
||||
),
|
||||
ethernet_interface,
|
||||
@@ -412,6 +436,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
topology.add_connection(conn_a_b)
|
||||
topology.add_connection(conn_b_c)
|
||||
topology.add_connection(conn_c_a)
|
||||
topology.add_connection(conn_b_a)
|
||||
topology.add_connection(conn_c_b)
|
||||
topology.add_connection(conn_a_c)
|
||||
|
||||
create_instance_command = CreateInstance(
|
||||
command_id=CommandId(),
|
||||
@@ -444,9 +471,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
idx_b = node_to_idx[node_id_b]
|
||||
idx_c = node_to_idx[node_id_c]
|
||||
|
||||
assert matrix[idx_a][idx_b] == "rdma_en3"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en4"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en5"
|
||||
logger.info(matrix)
|
||||
|
||||
assert matrix[idx_a][idx_b] == "rdma_en4"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en3"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en3"
|
||||
|
||||
assert ":" in instance.mlx_ibv_coordinator
|
||||
assert not instance.mlx_ibv_coordinator.startswith("169.254")
|
||||
|
||||
@@ -252,9 +252,5 @@ def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> Sta
    if not topology.contains_connection(event.edge):
        return state
    topology.remove_connection(event.edge)
    if not topology.contains_connection(event.edge) and topology.contains_connection(
        event.edge.reverse()
    ):
        topology.remove_connection(event.edge.reverse())
    # TODO: Clean up removing the reverse connection
    return state.model_copy(update={"topology": topology})
@@ -14,32 +14,32 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
"deepseek-v3-0324:4bit": ModelCard(
|
||||
short_id="deepseek-v3-0324:4bit",
|
||||
model_id="mlx-community/DeepSeek-V3-0324-4bit",
|
||||
name="DeepSeek V3 0324 (4-bit)",
|
||||
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
|
||||
pretty_name="DeepSeek V3 0324 (4-bit)",
|
||||
storage_size=Memory.from_kb(409706307),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
"deepseek-v3-0324": ModelCard(
|
||||
short_id="deepseek-v3-0324",
|
||||
model_id="mlx-community/DeepSeek-v3-0324-8bit",
|
||||
name="DeepSeek V3 0324 (8-bit)",
|
||||
description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
|
||||
pretty_name="DeepSeek V3 0324 (8-bit)",
|
||||
storage_size=Memory.from_kb(754706307),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
# "deepseek-v3-0324:4bit": ModelCard(
|
||||
# short_id="deepseek-v3-0324:4bit",
|
||||
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
|
||||
# name="DeepSeek V3 0324 (4-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3-0324": ModelCard(
|
||||
# short_id="deepseek-v3-0324",
|
||||
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
|
||||
# name="DeepSeek V3 0324 (8-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
"deepseek-v3.1": ModelCard(
|
||||
short_id="deepseek-v3.1",
|
||||
model_id="mlx-community/DeepSeek-V3.1-8bit",
|
||||
@@ -67,32 +67,32 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
),
|
||||
# deepseek r1
|
||||
"deepseek-r1-0528:4bit": ModelCard(
|
||||
short_id="deepseek-r1-0528:4bit",
|
||||
model_id="mlx-community/DeepSeek-R1-0528-4bit",
|
||||
name="DeepSeek-R1-0528 (4-bit)",
|
||||
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
|
||||
pretty_name="DeepSeek R1 671B (4-bit)",
|
||||
storage_size=Memory.from_kb(409706307),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
"deepseek-r1-0528": ModelCard(
|
||||
short_id="deepseek-r1-0528",
|
||||
model_id="mlx-community/DeepSeek-R1-0528-8bit",
|
||||
name="DeepSeek-R1-0528 (8-bit)",
|
||||
description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
|
||||
pretty_name="DeepSeek R1 671B (8-bit)",
|
||||
storage_size=Memory.from_bytes(754998771712),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
# "deepseek-r1-0528:4bit": ModelCard(
|
||||
# short_id="deepseek-r1-0528:4bit",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
|
||||
# name="DeepSeek-R1-0528 (4-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
|
||||
# pretty_name="DeepSeek R1 671B (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-r1-0528": ModelCard(
|
||||
# short_id="deepseek-r1-0528",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
|
||||
# name="DeepSeek-R1-0528 (8-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
|
||||
# pretty_name="DeepSeek R1 671B (8-bit)",
|
||||
# storage_size=Memory.from_bytes(754998771712),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
short_id="kimi-k2-instruct-4bit",
|
||||
@@ -228,19 +228,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=32,
|
||||
),
|
||||
),
|
||||
"phi-3-mini:128k": ModelCard(
|
||||
short_id="phi-3-mini:128k",
|
||||
model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
|
||||
name="Phi 3 Mini 128k",
|
||||
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
pretty_name="Phi 3 Mini 128k",
|
||||
storage_size=Memory.from_kb(2099262),
|
||||
n_layers=32,
|
||||
),
|
||||
),
|
||||
# "phi-3-mini:128k": ModelCard(
|
||||
# short_id="phi-3-mini:128k",
|
||||
# model_id="mlx-community/Phi-3-mini-128k-instruct-4bit",
|
||||
# name="Phi 3 Mini 128k",
|
||||
# description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
# pretty_name="Phi 3 Mini 128k",
|
||||
# storage_size=Memory.from_kb(2099262),
|
||||
# n_layers=32,
|
||||
# ),
|
||||
# ),
|
||||
# qwen3
|
||||
"qwen3-0.6b": ModelCard(
|
||||
short_id="qwen3-0.6b",
|
||||
@@ -268,19 +268,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=48,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b": ModelCard(
|
||||
short_id="qwen3-235b-a22b",
|
||||
model_id="mlx-community/Qwen3-235B-A22B-4bit",
|
||||
name="Qwen3 235B, Active 22B (4-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"),
|
||||
pretty_name="Qwen3 235B, Active 22B (4-bit)",
|
||||
storage_size=Memory.from_kb(123207680),
|
||||
n_layers=94,
|
||||
),
|
||||
),
|
||||
# "qwen3-235b-a22b": ModelCard(
|
||||
# short_id="qwen3-235b-a22b",
|
||||
# model_id="mlx-community/Qwen3-235B-A22B-4bit",
|
||||
# name="Qwen3 235B, Active 22B (4-bit)",
|
||||
# description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/Qwen3-235B-A22B-4bit"),
|
||||
# pretty_name="Qwen3 235B, Active 22B (4-bit)",
|
||||
# storage_size=Memory.from_kb(123207680),
|
||||
# n_layers=94,
|
||||
# ),
|
||||
# ),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
short_id="qwen3-235b-a22b-8bit",
|
||||
model_id="mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit",
|
||||
@@ -308,31 +308,31 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=40,
|
||||
),
|
||||
),
|
||||
"granite-3.3-8b": ModelCard(
|
||||
short_id="granite-3.3-8b",
|
||||
model_id="mlx-community/granite-3.3-8b-instruct-fp16",
|
||||
name="Granite 3.3 8B",
|
||||
description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
pretty_name="Granite 3.3 8B",
|
||||
storage_size=Memory.from_kb(15958720),
|
||||
n_layers=40,
|
||||
),
|
||||
),
|
||||
# "granite-3.3-8b": ModelCard(
|
||||
# short_id="granite-3.3-8b",
|
||||
# model_id="mlx-community/granite-3.3-8b-instruct-fp16",
|
||||
# name="Granite 3.3 8B",
|
||||
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
# pretty_name="Granite 3.3 8B",
|
||||
# storage_size=Memory.from_kb(15958720),
|
||||
# n_layers=40,
|
||||
# ),
|
||||
# ),
|
||||
# smol-lm
|
||||
"smol-lm-135m": ModelCard(
|
||||
short_id="smol-lm-135m",
|
||||
model_id="mlx-community/SmolLM-135M-4bit",
|
||||
name="Smol LM 135M",
|
||||
description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
|
||||
pretty_name="Smol LM 135M",
|
||||
storage_size=Memory.from_kb(73940),
|
||||
n_layers=30,
|
||||
),
|
||||
),
|
||||
# "smol-lm-135m": ModelCard(
|
||||
# short_id="smol-lm-135m",
|
||||
# model_id="mlx-community/SmolLM-135M-4bit",
|
||||
# name="Smol LM 135M",
|
||||
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
|
||||
# pretty_name="Smol LM 135M",
|
||||
# storage_size=Memory.from_kb(73940),
|
||||
# n_layers=30,
|
||||
# ),
|
||||
# ),
|
||||
}
|
||||
|
||||
@@ -70,6 +70,9 @@ class Topology:
        if connection.send_back_node_id not in self._node_id_to_rx_id_map:
            self.add_node(NodeInfo(node_id=connection.send_back_node_id))

        if connection in self._edge_id_to_rx_id_map:
            return

        src_id = self._node_id_to_rx_id_map[connection.local_node_id]
        sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
@@ -24,13 +24,20 @@ class ModelListModel(BaseModel):


class ModelList(BaseModel):
    object: str = "list"
    object: Literal["list"] = "list"
    data: list[ModelListModel]


class ChatCompletionMessageText(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ChatCompletionMessage(BaseModel):
    role: Literal["system", "user", "assistant", "developer", "tool", "function"]
    content: str | None = None
    content: (
        str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
    ) = None
    thinking: str | None = None  # Added for GPT-OSS harmony format support
    name: str | None = None
    tool_calls: list[dict[str, Any]] | None = None
@@ -55,20 +62,6 @@ class Logprobs(BaseModel):
    content: list[LogprobsContentItem] | None = None


class StreamingChoiceResponse(BaseModel):
    index: int
    delta: ChatCompletionMessage
    logprobs: Logprobs | None = None
    finish_reason: FinishReason | None = None


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatCompletionMessage
    logprobs: Logprobs | None = None
    finish_reason: FinishReason | None = None


class PromptTokensDetails(BaseModel):
    cached_tokens: int = 0
    audio_tokens: int = 0
@@ -89,6 +82,21 @@ class Usage(BaseModel):
    completion_tokens_details: CompletionTokensDetails | None = None


class StreamingChoiceResponse(BaseModel):
    index: int
    delta: ChatCompletionMessage
    logprobs: Logprobs | None = None
    finish_reason: FinishReason | None = None
    usage: Usage | None = None


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatCompletionMessage
    logprobs: Logprobs | None = None
    finish_reason: FinishReason | None = None


class ChatCompletionResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
@@ -125,8 +133,8 @@ class CreateInstanceTaskParams(BaseModel):
    # TODO: in future the user could specify a specific Instance, not just a model_id
    model_id: str
    sharding: Sharding = Sharding.Pipeline
    # TODO: fix
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1


class DeleteInstanceTaskParams(BaseModel):
@@ -29,6 +29,7 @@ class CreateInstance(BaseCommand):
    model_meta: ModelMetadata
    sharding: Sharding
    instance_meta: InstanceMeta
    min_nodes: int


class DeleteInstance(BaseCommand):
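Putting the API pieces together, a launch request now carries min_nodes alongside the existing fields, matching both CreateInstanceTaskParams and the dashboard's fetch body. A minimal client sketch follows; the /instance path and port are assumptions, since the hunks above show only the JSON body and the server-side handler, not the URL.

import json
import urllib.request

payload = {
    "model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",  # any model id known to the cluster
    "sharding": "Tensor",
    "instance_meta": "MlxIbv",
    "min_nodes": 2,
}
req = urllib.request.Request(
    "http://localhost:52415/instance",  # assumed endpoint and port
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(resp.status, resp.read().decode())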
@@ -12,20 +12,17 @@ class NodeInfo(CamelCaseModel):
class Connection(CamelCaseModel):
    local_node_id: NodeId
    send_back_node_id: NodeId
    send_back_multiaddr: Multiaddr | None
    send_back_multiaddr: Multiaddr
    connection_profile: ConnectionProfile | None = None

    def __hash__(self) -> int:
        if self.send_back_multiaddr:
            return hash(
                (
                    self.local_node_id,
                    self.send_back_node_id,
                    self.send_back_multiaddr.address,
                )
        return hash(
            (
                self.local_node_id,
                self.send_back_node_id,
                self.send_back_multiaddr.address,
            )
        else:
            return hash((self.local_node_id, self.send_back_node_id))
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Connection):
@@ -37,13 +34,4 @@ class Connection(CamelCaseModel):
        )

    def is_thunderbolt(self) -> bool:
        return self.send_back_multiaddr is not None and str(
            self.send_back_multiaddr.ipv4_address
        ).startswith("169.254")

    def reverse(self) -> "Connection":
        return Connection(
            local_node_id=self.send_back_node_id,
            send_back_node_id=self.local_node_id,
            send_back_multiaddr=None,
        )
        return str(self.send_back_multiaddr.ipv4_address).startswith("200.0")
@@ -41,6 +41,7 @@ class BoundInstance(CamelCaseModel):
    instance: Instance
    bound_runner_id: RunnerId

    @property
    def bound_shard(self) -> ShardMetadata:
        shard = self.instance.shard(self.bound_runner_id)
        assert shard is not None
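bound_shard is now a property rather than a method, which is why the call sites later in this commit drop the trailing parentheses. A toy, self-contained sketch of the same refactor (the Shard and BoundInstanceSketch names here are made up for illustration):

class Shard:
    device_rank: int = 0


class BoundInstanceSketch:
    def __init__(self, shard: Shard) -> None:
        self._shard = shard

    @property
    def bound_shard(self) -> Shard:
        # As a property, callers write `obj.bound_shard` instead of `obj.bound_shard()`.
        shard = self._shard
        assert shard is not None
        return shard


bi = BoundInstanceSketch(Shard())
print(bi.bound_shard.device_rank)  # accessed without parentheses, as in the updated call sites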
@@ -48,6 +48,7 @@ class RunnerRunning(BaseRunnerStatus):
class RunnerShutdown(BaseRunnerStatus):
    pass


class RunnerFailed(BaseRunnerStatus):
    error_message: str | None = None
34
src/exo/utils/banner.py
Normal file
@@ -0,0 +1,34 @@
def print_startup_banner(port: int) -> None:
    """Print a prominent startup banner with API endpoint information."""
    banner = """
╔═══════════════════════════════════════════════════════════════════════╗
║ ║
║ ███████╗██╗ ██╗ ██████╗ ║
║ ██╔════╝╚██╗██╔╝██╔═══██╗ ║
║ █████╗ ╚███╔╝ ██║ ██║ ║
║ ██╔══╝ ██╔██╗ ██║ ██║ ║
║ ███████╗██╔╝ ██╗╚██████╔╝ ║
║ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ║
║ ║
║ Distributed AI Inference Cluster ║
║ ║
╚═══════════════════════════════════════════════════════════════════════╝
"""

    dashboard_url = f"http://localhost:{port}"

    api_info = f"""
╔═══════════════════════════════════════════════════════════════════════╗
║ ║
║ 🌐 Dashboard & API Ready ║
║ ║
║ {dashboard_url}{" " * (69 - len(dashboard_url))}║
║ ║
║ Click the URL above to open the dashboard in your browser ║
║ ║
╚═══════════════════════════════════════════════════════════════════════╝
"""

    print(banner)
    print(api_info)
    print()
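A quick way to preview the new banner locally, assuming the src/ layout maps src/exo/utils/banner.py to the exo.utils.banner module as the other imports in this commit suggest:

from exo.utils.banner import print_startup_banner

# The port is only interpolated into the dashboard URL; 8000 matches the port used
# by the curl script further down in this commit.
print_startup_banner(8000)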
@@ -139,7 +139,9 @@ class MpSender[T]:
    # == unique to Mp channels ==
    def join(self) -> None:
        """Ensure any queued messages are resolved before continuing"""
        assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
        assert self._state.closed.is_set(), (
            "Mp channels must be closed before being joined"
        )
        self._state.buffer.join_thread()

    # == context manager support ==
@@ -209,7 +211,9 @@ class MpReceiver[T]:
    # == unique to Mp channels ==
    def join(self) -> None:
        """Block until all enqueued messages are drained off our side of the buffer"""
        assert self._state.closed.is_set(), "Mp channels must be closed before being joined"
        assert self._state.closed.is_set(), (
            "Mp channels must be closed before being joined"
        )
        self._state.buffer.join_thread()

    # == iterator support ==
@@ -100,12 +100,12 @@ def _model_needs_download(
    for runner in runners.values():
        if (
            isinstance(runner.status, RunnerWaitingForModel)
            and runner.bound_instance.bound_shard() not in download_status
            and runner.bound_instance.bound_shard not in download_status
        ):
            # We don't invalidate download_status randomly in case a file gets deleted on disk
            return DownloadModel(
                instance_id=runner.bound_instance.instance.instance_id,
                shard_metadata=runner.bound_instance.bound_shard(),
                shard_metadata=runner.bound_instance.bound_shard,
            )


@@ -160,7 +160,7 @@ def _ready_to_warmup(
                )
                for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
            )
            and runner.bound_instance.bound_shard().device_rank != 0
            and runner.bound_instance.bound_shard.device_rank != 0
        )
        or (
            all(
@@ -170,7 +170,7 @@ def _ready_to_warmup(
                for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
                if global_runner_id != runner.bound_instance.bound_runner_id
            )
            and runner.bound_instance.bound_shard().device_rank == 0
            and runner.bound_instance.bound_shard.device_rank == 0
        )
    ):
        return StartWarmup(instance_id=runner.bound_instance.instance.instance_id)
@@ -6,6 +6,9 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.engines.mlx import Model

# from exo.engines.mlx.cache import KVPrefixCache
from exo.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.engines.mlx.utils_mlx import (
    apply_chat_template,
    make_kv_cache,
@@ -70,6 +73,8 @@ def warmup_inference(
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=65536,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        logger.info("Generated warmup token: " + str(_r.text))
        tokens_generated += 1
@@ -94,19 +99,19 @@ def mlx_generate(
        chat_task_data=task,
    )

    cache = make_kv_cache(
        model=model,
    )
    caches = make_kv_cache(model=model)

    max_tokens = task.max_tokens or 1000
    max_tokens = task.max_tokens or MAX_TOKENS
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        prompt_cache=cache,
        prompt_cache=caches,
        prefill_step_size=65536,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        logger.info(out.text)
        if out.finish_reason is not None and out.finish_reason not in get_args(
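One behavioural detail in mlx_generate above: `task.max_tokens or MAX_TOKENS` falls back to the constant for any falsy value, so an explicit 0 is treated the same as None. A small illustration; the MAX_TOKENS value below is hypothetical (the real constant comes from exo.engines.mlx.constants):

MAX_TOKENS = 1000  # hypothetical value, for illustration only


def resolve_max_tokens(requested: int | None) -> int:
    # Same pattern as mlx_generate: None and 0 both fall back to the default.
    return requested or MAX_TOKENS


assert resolve_max_tokens(None) == MAX_TOKENS
assert resolve_max_tokens(0) == MAX_TOKENS  # 0 means "use the default", not "generate nothing"
assert resolve_max_tokens(64) == 64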
@@ -22,7 +22,7 @@ from exo.shared.types.tasks import (
)
from exo.shared.types.worker.commands_runner import (
    GenerationResponse,
    TokenizedResponse,
    # TokenizedResponse,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
@@ -31,12 +31,12 @@ from exo.shared.types.worker.runners import (
    RunnerLoading,
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerStatus,
    RunnerWaitingForModel,
    RunnerWarmingUp,
    RunnerShutdown
)
from exo.utils.channels import MpReceiver, MpSender, ClosedResourceError
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.generate import mlx_generate, warmup_inference

@@ -49,7 +49,7 @@ def main(
    instance, runner_id, shard_metadata = (
        bound_instance.instance,
        bound_instance.bound_runner_id,
        bound_instance.bound_shard(),
        bound_instance.bound_shard,
    )
    try:
        logger.info("hello from the runner")
@@ -115,6 +115,7 @@ def main(
            model=model,
            tokenizer=tokenizer,
            sampler=sampler,
            # kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
        )
        logger.info(f"warmed up by generating {toks} tokens")
        logger.info(
@@ -185,9 +186,9 @@ def main(
                        ),
                    )
                )
            case TokenizedResponse():
                # TODO: something here ig
                logger.info("Finished tokenizing?")
            # case TokenizedResponse():
            # TODO: something here ig
            # logger.info("Finished tokenizing?")

            current_status = RunnerReady()
            logger.info("runner ready")
@@ -212,9 +213,7 @@ def main(
            )
        )
        event_sender.send(
            RunnerStatusUpdated(
                runner_id=runner_id, runner_status=RunnerShutdown()
            )
            RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
        )
    except ClosedResourceError:
        logger.warning("runner communication closed unexpectedly")
@@ -67,7 +67,7 @@ class RunnerSupervisor:
            daemon=True,
        )

        shard_metadata = bound_instance.bound_shard()
        shard_metadata = bound_instance.bound_shard

        self = cls(
            bound_instance=bound_instance,
@@ -109,12 +109,13 @@ class RunnerSupervisor:
        if not self.runner_process.is_alive():
            return

        logger.critical("Runner process didn't respond to SIGKILL. System resources may have leaked")
        logger.critical(
            "Runner process didn't respond to SIGKILL. System resources may have leaked"
        )

    def shutdown(self):
        assert self._tg
        self._tg.cancel_scope.cancel()


    async def start_task(self, task: Task):
        event = anyio.Event()
@@ -126,7 +127,6 @@ class RunnerSupervisor:
            return
        await event.wait()


    async def _forward_events(self):
        with self._ev_recv as events:
            try:
@@ -140,7 +140,6 @@ class RunnerSupervisor:
            except (ClosedResourceError, BrokenResourceError) as e:
                await self._check_runner(e)


    def __del__(self) -> None:
        if self.runner_process.is_alive():
            logger.warning("RunnerSupervisor was not stopped cleanly.")
@@ -152,7 +151,7 @@ class RunnerSupervisor:
        await to_thread.run_sync(self.runner_process.join, 1)
        rc = self.runner_process.exitcode
        if rc == 0:
            #
            #
            return

        if isinstance(rc, int) and rc < 0:
@@ -1,4 +1,5 @@
import pytest
from exo.worker.common import AssignedRunner

from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.state import State
@@ -27,7 +28,6 @@ from exo.shared.types.worker.runners import (
    RunningRunnerStatus,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.common import AssignedRunner
from exo.worker.main import Worker
from exo.worker.plan import plan
from exo.worker.tests.constants import (
@@ -13,9 +13,9 @@ QUERY="$*"
curl -sN -X POST "http://$HOST:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d "{
    \"model\": \"mlx-community/Llama-3.3-70B-Instruct-8bit\",
    \"model\": \"mlx-community/Kimi-K2-Thinking\",
    \"stream\": true,
    \"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\" }]
    \"messages\": [{ \"role\": \"user\", \"content\": \"$QUERY\"}]
  }" |
  grep --line-buffered '^data:' |
  grep --line-buffered -v 'data: \[DONE\]' |
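For reference, the streaming request made by the shell script above can also be issued from Python. This sketch assumes the third-party requests package and a node serving the API on localhost:8000; it prints each SSE data payload as it arrives:

import requests

# Mirrors the curl invocation above: stream a chat completion from the local exo API.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "mlx-community/Kimi-K2-Thinking",
        "stream": True,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    stream=True,
    timeout=300,
)
for line in resp.iter_lines(decode_unicode=True):
    # Server-sent events: payload lines start with "data: "; "[DONE]" marks the end.
    if line and line.startswith("data: ") and line != "data: [DONE]":
        print(line[len("data: "):])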