Compare commits


12 Commits

Author SHA1 Message Date
Alex Cheema
771a94d944 debug: log dense vs MoE layer counts in DeepSeekShardingStrategy
This will show how many layers use shard_linear (dense) vs
shard_inplace (MoE) for kimi-k2 and similar models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:29:53 +00:00
Alex Cheema
0c266151ca fix: remove model.shard() bypass - always use custom strategies
The model.shard() call was bypassing our custom sharding strategies
for models that have a built-in shard method. This could be causing
the inconsistent behavior between different models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:08:40 +00:00
Alex Cheema
556f5a0f6d debug: add logging to identify which sharding strategy is used
Log model type, whether it has built-in shard method, and which
strategy is selected. This will help identify patterns between
working and broken models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:04:39 +00:00
Alex Cheema
1d0b121457 revert: restore MLX_METAL_FAST_SYNCH to original location
Revert to setting MLX_METAL_FAST_SYNCH in bootstrap.py before model
loading. Setting it after loading doesn't work properly.

The hang issue with certain models (gpt-oss-20b) + jaccl + fast_synch
needs further investigation into why those specific models trigger
the fence polling deadlock.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:44:40 +00:00
Alex Cheema
f036add84f fix: defer MLX_METAL_FAST_SYNCH until after model loading
MLX_METAL_FAST_SYNCH=1 causes hangs during lazy weight evaluation
with certain models (e.g., gpt-oss-20b) on the jaccl backend. The
fast sync mode appears to conflict with lazy array materialization.

Fix by setting MLX_METAL_FAST_SYNCH=1 only AFTER model loading
completes. This preserves the performance benefit during inference
while avoiding the loading hang.

Also cleaned up debug logging added during investigation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:35:07 +00:00
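
For context, a rough sketch of the ordering this commit describes; `load` here is mlx_lm's generic loader standing in for exo's own load path, so treat it as an illustration rather than the actual change:

```python
# Sketch: enable MLX_METAL_FAST_SYNCH only after the weights are materialized,
# since enabling it before lazy weight evaluation hangs some models on jaccl.
import os

from mlx_lm import load  # stand-in for exo's lazy load + shard path


def load_then_enable_fast_synch(model_path: str):
    # Do NOT set MLX_METAL_FAST_SYNCH=1 yet: with it on, lazy weight
    # evaluation on the jaccl backend can deadlock (e.g. gpt-oss-20b).
    model, tokenizer = load(model_path)
    # Weights are loaded; turn fast sync on for the inference phase.
    os.environ["MLX_METAL_FAST_SYNCH"] = "1"
    return model, tokenizer
```

Note that the commit above (1d0b121457) reverts exactly this ordering: setting the variable after loading did not work properly, so it moved back to bootstrap.py.
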
Alex Cheema
d63c8c86a8 fix: use tree_flatten for nested parameter dict
model.parameters() returns nested dicts, not flat. Use
mlx.utils.tree_flatten to get a flat list of (name, array) tuples.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:32:55 +00:00
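
A minimal, self-contained illustration of the nested-vs-flat distinction, using a toy MLX module rather than anything from the repo:

```python
# model.parameters() yields a nested dict of arrays; tree_flatten converts it
# into a flat list of ("dotted.name", array) pairs that can be iterated and
# evaluated one at a time.
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

toy_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

nested = toy_model.parameters()   # nested dicts/lists of mx.array leaves
flat = tree_flatten(nested)       # [("layers.0.weight", array), ...]

for name, value in flat:
    mx.eval(value)                # materialize each parameter individually
    print(name, value.shape)
```
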
Alex Cheema
80608eaf64 debug: more granular logging to find exact hang location
Log before/after each step: model.parameters(), dict conversion,
and each individual param eval to isolate where hang occurs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:32:32 +00:00
Alex Cheema
fc32199653 debug: eval parameters one-by-one to identify hang location
Iterate through model.parameters() and eval each one individually
with logging to pinpoint exactly which parameter causes the hang.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:30:32 +00:00
Alex Cheema
028e29a6d8 test: try barrier-only fix without preloading all weights
Remove the early mx.eval that loads the entire model - just keep the
barrier to sync nodes before sharding. This is important because
preloading the entire model on each node would OOM for large models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:20:02 +00:00
Alex Cheema
3941855ad6 debug: add logging around shard_linear and shard_inplace calls
Adding logging to understand where distributed communication happens
during tensor parallelism setup.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:11:39 +00:00
Alex Cheema
1933b224c9 fix: materialize lazy weights before distributed sharding
The jaccl backend deadlocks when mx.eval() is called on lazy weights
that have been wrapped with distributed sharding operations. The issue
is that lazy weight loading (downloading from HF) and distributed
communication were happening simultaneously.

Fix by:
1. Calling mx.eval(model.parameters()) BEFORE tensor_auto_parallel
2. Adding a barrier to ensure all nodes have weights before sharding
3. Then applying sharding to already-materialized weights

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:10:49 +00:00
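
A hedged sketch of the three steps above; `tensor_auto_parallel` is passed in as a callable, and the barrier is approximated with an `all_sum` on a dummy scalar rather than exo's `mx_barrier` helper:

```python
# Sketch: materialize lazy weights and sync all ranks before any distributed
# sharding wraps them, so weight download and collective communication never
# overlap on the jaccl backend.
import mlx.core as mx


def materialize_then_shard(model, group, tensor_auto_parallel):
    # 1. Force the lazily-loaded weights to be downloaded and materialized.
    mx.eval(model.parameters())
    # 2. Barrier: an all_sum on a dummy scalar returns only once every rank
    #    has reached this point.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), group=group))
    # 3. Only now apply tensor-parallel sharding to materialized weights.
    return tensor_auto_parallel(model, group)
```

The later commit 028e29a6d8 drops step 1 again and keeps only the barrier, since materializing the whole model on every node would OOM for large models.
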
Alex Cheema
737d97a2d4 Add detailed logging for jaccl/tensor parallel model loading
Add logging at critical points to debug MlxJacclInstance stuck in
RunnerLoading state:

- Before/after mx.distributed.init(backend="jaccl")
- Before/after shard_and_load, load_model
- Before/after tensor_auto_parallel with sharding strategy info
- Progress logs during GptOss layer sharding
- Before/after mx.eval(model.parameters()) and mx.eval(model)
- Before/after mx_barrier(group) sync

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 20:58:25 +00:00
17 changed files with 1078 additions and 759 deletions

View File

@@ -91,6 +91,51 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## API Reference
The API is served at `http://localhost:52415` by default. Key files:
- `docs/api.md`: Full API documentation
- `src/exo/master/api.py`: FastAPI implementation
- `src/exo/shared/types/api.py`: Request/response Pydantic models
### Key Endpoints
```
GET /node_id # Current master node ID
GET /state # Full cluster state (topology, instances, downloads, etc.)
GET /events # Event log for debugging
POST /instance # Create model instance
GET /instance/{id} # Get instance details
DELETE /instance/{id} # Delete instance
GET /instance/previews # Preview placements for a model
GET /instance/placement # Compute placement without creating
GET /models # List available models
GET /v1/models # OpenAI-compatible model list
POST /v1/chat/completions # OpenAI-compatible chat completions (streaming/non-streaming)
POST /bench/chat/completions # Chat completions with performance stats
```
### Useful curl Commands
```bash
# Check cluster state
curl -s http://localhost:52415/state | python3 -m json.tool
# List models
curl -s http://localhost:52415/models | python3 -m json.tool
# Preview placements for a model
curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | python3 -m json.tool
# Chat completion
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}]}'
```
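Since `/v1/chat/completions` is OpenAI-compatible, the same request can also be made from Python with the `openai` client pointed at the local server; a small sketch, assuming the package is installed and a `llama-3.2-1b` instance is running:
```python
# Stream a chat completion from exo's OpenAI-compatible endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:52415/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="llama-3.2-1b",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # the endpoint also supports non-streaming responses
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```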
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
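A minimal illustration of what a test looks like under this configuration (hypothetical test, not taken from the repo):
```python
# With asyncio_mode = "auto", async test functions need no extra decorator,
# and EXO_TESTS=1 is set while the suite runs.
import asyncio
import os


async def test_exo_tests_env_var_is_set():
    await asyncio.sleep(0)  # any awaited coroutine runs on the auto event loop
    assert os.environ.get("EXO_TESTS") == "1"
```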

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:

View File

@@ -241,9 +241,6 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn

View File

@@ -60,39 +60,12 @@
return models;
});
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
// Auto-select the first available model if none is selected
$effect(() => {
const models = availableModels();
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -400,8 +400,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -761,10 +763,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -773,24 +771,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);

View File

@@ -1,5 +1,3 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -6,6 +6,8 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"pydantic>=2.11.7",
"fastapi>=0.116.1",
"filelock>=3.18.0",
@@ -21,7 +23,6 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
]
[project.scripts]

View File

@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -61,6 +67,8 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -373,8 +381,35 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -382,10 +417,16 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -401,11 +442,11 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -417,7 +458,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -425,7 +466,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -454,7 +495,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -462,7 +503,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -503,6 +544,8 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -519,16 +562,17 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -545,7 +589,10 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
return response
def _calculate_total_available_memory(self) -> Memory:

View File

@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -425,15 +425,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,

View File

@@ -7,13 +7,13 @@ import time
import traceback
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal, cast
from typing import Callable, Literal
from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
import httpx
from loguru import logger
from pydantic import (
BaseModel,
@@ -207,22 +207,23 @@ async def _fetch_file_list(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
response = await session.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Failed to fetch file list: {response.status_code}")
data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
async def get_download_headers() -> dict[str, str]:
@@ -230,25 +231,31 @@ async def get_download_headers() -> dict[str, str]:
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
) -> httpx.AsyncClient:
) -> aiohttp.ClientSession:
if timeout_profile == "short":
total_timeout = 30
connect_timeout = 10
read_timeout = 30
sock_read_timeout = 30
sock_connect_timeout = 10
else:
total_timeout = 1800
connect_timeout = 60
read_timeout = 1800
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
return httpx.AsyncClient(
verify=ssl_context,
timeout=httpx.Timeout(
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
read=read_timeout,
write=total_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
),
)
@@ -275,25 +282,23 @@ async def file_meta(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.head(url, headers=headers) as r,
):
r = await session.head(url, headers=headers)
if r.status_code == 307:
if r.status == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = cast(str | None, r.headers.get("location"))
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
content_length = cast(
str | None,
r.headers.get("x-linked-size") or r.headers.get("content-length"),
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
content_length = 0 if content_length is None else int(content_length)
etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
etag = trim_etag(etag)
@@ -352,17 +357,17 @@ async def _download_file(
n_read = resume_byte_pos or 0
async with (
create_http_session(timeout_profile="long") as session,
session.get(url, headers=headers) as r,
):
r = await session.get(url, headers=headers)
if r.status_code == 404:
if r.status == 404:
raise FileNotFoundError(f"File not found: {url}")
assert r.status_code in [200, 206], (
f"Failed to download {path} from {url}: {r.status_code}"
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
)
async with aiofiles.open(
partial_path, "ab" if resume_byte_pos else "wb"
) as f:
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
while chunk := await r.content.read(8 * 1024 * 1024):
n_read = n_read + (await f.write(chunk))
on_progress(n_read, length, False)

View File

@@ -228,15 +228,10 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
logger.info(f"tensor_auto_parallel: model type = {type(model).__name__}")
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
logger.info("Using LlamaShardingStrategy")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -245,7 +240,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
logger.info("Using DeepSeekShardingStrategy")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -254,6 +249,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, MiniMaxModel):
logger.info("Using MiniMaxShardingStrategy")
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
@@ -262,6 +258,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
logger.info("Using QwenShardingStrategy")
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -270,6 +267,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
logger.info("Using GptOssShardingStrategy for tensor parallelism")
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
@@ -352,6 +350,8 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
dense_count = 0
moe_count = 0
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
@@ -370,6 +370,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
dense_count += 1
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -377,6 +378,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
moe_count += 1
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
@@ -386,6 +388,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
logger.info(f"DeepSeekShardingStrategy: {dense_count} dense layers (shard_linear), {moe_count} MoE layers (shard_inplace)")
return model
@@ -481,7 +484,6 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)

View File

@@ -20,7 +20,6 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -163,7 +162,9 @@ def mlx_distributed_init(
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
logger.info(f"rank {rank} BEFORE mx.distributed.init(backend='jaccl')")
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} AFTER mx.distributed.init - group created")
logger.info(f"Rank {rank} mlx distributed initialization complete")
@@ -200,10 +201,12 @@ def load_mlx_items(
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
logger.info("Starting distributed shard_and_load")
start_time = time.perf_counter()
logger.info(f"BEFORE shard_and_load for model {bound_instance.bound_shard.model_meta.model_id}")
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
logger.info(f"AFTER shard_and_load completed")
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
@@ -218,8 +221,10 @@ def shard_and_load(
group: Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
logger.info(f"shard_and_load: model_path={model_path}")
logger.info("BEFORE load_model (lazy=True)")
model, _ = load_model(model_path, lazy=True, strict=False)
logger.info("AFTER load_model")
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
@@ -253,8 +258,6 @@ def shard_and_load(
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# TODO: Do we need this?
mx.eval(model)
logger.debug("SHARDED")
@@ -366,8 +369,6 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -399,11 +400,6 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -1,15 +1,6 @@
import time
from collections.abc import Generator
from functools import cache
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -162,19 +153,11 @@ def main(
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
@@ -224,43 +207,6 @@ def main(
break
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -1,63 +1,60 @@
import anyio
import httpx
from anyio import create_task_group
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
body = response.read().decode("utf-8").strip()
remote_node_id = NodeId(body)
break
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
except (
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
) as e:
last_error = e
await anyio.sleep(1)
return NodeId(body) or None
except OSError:
return None
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()
else:
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
else:
logger.warning(
f"malformed response from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
@@ -77,33 +74,18 @@ async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
client,
)
return reachable

uv.lock (generated)
View File

File diff suppressed because it is too large.