Compare commits

..

18 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
8d6e52bdb5 Try wrong shardings 2026-01-16 19:39:44 +00:00
Ryuichi Leo Takashige
a412ec6d04 Comment out custom moe layer 2026-01-16 19:33:22 +00:00
Ryuichi Leo Takashige
8d00c6ad44 Comment out sharding group only 2026-01-16 19:28:31 +00:00
Ryuichi Leo Takashige
9bf1bb3025 remove moe sharding 2026-01-16 19:24:06 +00:00
Ryuichi Leo Takashige
c613df4d8d formatting 2026-01-16 18:48:57 +00:00
Ryuichi Leo Takashige
f5d1532245 revert api change 2026-01-16 18:46:54 +00:00
Ryuichi Leo Takashige
659fbdf7ea Handle model timeouts
- Add eval with a timeout.
- Add fast synch flag

Timeout mx barrier too
2026-01-16 18:40:48 +00:00
Ryuichi Leo Takashige
313b24fe04 Handle broken resource error gracefully
# Conflicts:
#	src/exo/master/api.py
2026-01-16 18:32:07 +00:00
Ryuichi Leo Takashige
b5bd0ffc98 Return error responses for Chat Completions
- Error chunks
- Use error handling in exo_bench.py
2026-01-16 18:30:55 +00:00
Evan
83c5285a80 reduce logs
previous commits logs were too verbose, this tones them down a bit
2026-01-16 14:05:47 +00:00
Evan Quiney
39ee2bf7bd switch from synchronous threaded pinging to an async implementation (#1170)
still seeing churn in our networking - lets properly rate limit it

## changes

added an httpx client with max connections with a persistent AsyncClient

## testing

deployed on cluster, discovery VASTLY more stable (the only deleted
edges were those discovered by mdns)
2026-01-16 13:20:03 +00:00
Sami Khan
991adfbd6f fix local network warning (#1136)
## Motivation

Local network warning banner was showing on fresh install even though
mDNS was working. The check would fail before the user had a chance to
grant permission via the macOS prompt.

## Changes

- Added `hasWorkedBefore` flag persisted in UserDefaults
- Only show warning if permission previously worked but now doesn't

## Why It Works

On fresh install, the check may fail (no permission yet), but
`hasWorkedBefore` is false so no warning shows. Once the user grants
permission and a check succeeds, we record it. Future failures (zombie
permission after restart) will show the warning since `hasWorkedBefore`
is now true.

## Test Plan

### Manual Testing
Run locally

### Automated Testing
N/A
2026-01-16 13:10:50 +00:00
rltakashige
4b3de6b984 Fix exo bench for transformers 5.x (#1168)
## Motivation
Prompt Sizer was broken as transformers 5.x tokenizers create
BatchEncodings which are essentially a dictionary of {input_ids: []}
instead of the list of input ids.

## Test Plan

### Manual Testing
Tested that exo bench runs as expected.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-16 12:39:22 +00:00
Evan
c8de3b90ea quiet rust logs
rust logs were too verbose - now only warnings propagate to python

entirely happy not to merge this and to clean up rust logging instead,
but this felt saner right now
2026-01-16 12:34:28 +00:00
Sami Khan
6e6567a802 resolve issue #1070 (#1076)
## Motivation

https://github.com/exo-explore/exo/issues/1070

## Changes

Added check in ChatForm.svelte to reset selectedChatModel when it no
longer matches any running instance.

## Why It Works

The $effect now detects when the selected model is stale (not in
availableModels()) and resets to the first available model.

## Test Plan

### Manual Testing

1. Create instance of Model A → Delete it → Create instance of Model B →
Chat
2. Verify request goes to Model B (not Model A)

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-15 20:00:41 +00:00
rltakashige
a735dad667 Parse GPT OSS in runner (#1160)
## Motivation

Simplification of API + moving model specific code to the runner

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Test Plan

### Manual Testing
Tested that GPT OSS outputs are parsed correctly on the dashboard.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-15 19:53:55 +00:00
rltakashige
aaf4e36bc3 FIX GPT OSS (#1165)
## Motivation

Adds several unmerged fixes for GPT OSS.
Also adds GPT OSS 20B MXFP4 Q8 instead of Q4 for numerical stability (as
this is unstable for MLX LM too)
<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->


## Test Plan

### Manual Testing
Manually tested. No further gibberish responses.

### Automated Testing
Ran EXO Bench - pipeline, tensor and single node work on both 20B and
120B models
2026-01-15 19:20:17 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps this is a connection
timing out issue? lets also re-try after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
57 changed files with 1221 additions and 744 deletions

View File

@@ -56,6 +56,11 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,30 +35,43 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
lastConnectionState = "connecting"
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
checkTask = Task { [weak self] in
guard let self else { return }
let result = await self.performCheck()
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
connection?.cancel()
connection = nil
@@ -84,22 +97,7 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
conn.stateUpdateHandler = { state in
switch state {
case .ready:
resumeOnce(.working)
@@ -108,6 +106,7 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -127,7 +126,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: 3_000_000_000)
try? await Task.sleep(nanoseconds: timeout)
let state = conn.state
switch state {
case .ready:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -241,6 +266,9 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn
@@ -296,6 +324,12 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -317,7 +351,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -396,7 +430,7 @@ def main() -> int:
):
continue
if 0 < n <= args.max_nodes:
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -438,7 +472,13 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)

View File

@@ -60,12 +60,39 @@
return models;
});
// Auto-select the first available model if none is selected
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
const models = availableModels();
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);

View File

@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -23,7 +23,7 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"tomlkit>=0.14.0",
"httpx>=0.28.1",
]
[project.scripts]

View File

@@ -1,15 +0,0 @@
short_id = "deepseek-v3.1-4bit"
model_id = "mlx-community/DeepSeek-V3.1-4bit"
name = "DeepSeek V3.1 (4-bit)"
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/DeepSeek-V3.1-4bit"
pretty_name = "DeepSeek V3.1 (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 405874409472

View File

@@ -1,15 +0,0 @@
short_id = "deepseek-v3.1-8bit"
model_id = "mlx-community/DeepSeek-V3.1-8bit"
name = "DeepSeek V3.1 (8-bit)"
description = "DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/DeepSeek-V3.1-8bit"
pretty_name = "DeepSeek V3.1 (8-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 765577920512

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.5-air-8bit"
model_id = "mlx-community/GLM-4.5-Air-8bit"
name = "GLM 4.5 Air 8bit"
description = "GLM 4.5 Air 8bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.5-Air-8bit"
pretty_name = "GLM 4.5 Air 8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
[metadata.storage_size]
in_bytes = 122406567936

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.5-air-bf16"
model_id = "mlx-community/GLM-4.5-Air-bf16"
name = "GLM 4.5 Air bf16"
description = "GLM 4.5 Air bf16"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.5-Air-bf16"
pretty_name = "GLM 4.5 Air bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 229780750336

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.7-4bit"
model_id = "mlx-community/GLM-4.7-4bit"
name = "GLM 4.7 4bit"
description = "GLM 4.7 4bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-4bit"
pretty_name = "GLM 4.7 4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 198556925568

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.7-6bit"
model_id = "mlx-community/GLM-4.7-6bit"
name = "GLM 4.7 6bit"
description = "GLM 4.7 6bit"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-6bit"
pretty_name = "GLM 4.7 6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 286737579648

View File

@@ -1,15 +0,0 @@
short_id = "glm-4.7-8bit-gs32"
model_id = "mlx-community/GLM-4.7-8bit-gs32"
name = "GLM 4.7 8bit (gs32)"
description = "GLM 4.7 8bit (gs32)"
tags = []
[metadata]
model_id = "mlx-community/GLM-4.7-8bit-gs32"
pretty_name = "GLM 4.7 8bit (gs32)"
n_layers = 91
hidden_size = 5120
supports_tensor = true
[metadata.storage_size]
in_bytes = 396963397248

View File

@@ -1,15 +0,0 @@
short_id = "gpt-oss-120b-MXFP4-Q8"
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
description = "OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon."
tags = []
[metadata]
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
pretty_name = "GPT-OSS 120B (MXFP4-Q8, MLX)"
n_layers = 36
hidden_size = 2880
supports_tensor = true
[metadata.storage_size]
in_bytes = 70652212224

View File

@@ -1,15 +0,0 @@
short_id = "gpt-oss-20b-4bit"
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
description = "OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization."
tags = []
[metadata]
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q4"
pretty_name = "GPT-OSS 20B (MXFP4-Q4, MLX)"
n_layers = 24
hidden_size = 2880
supports_tensor = true
[metadata.storage_size]
in_bytes = 12025908224

View File

@@ -1,15 +0,0 @@
short_id = "kimi-k2-instruct-4bit"
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
name = "Kimi K2 Instruct (4-bit)"
description = "Kimi K2 is a large language model trained on the Kimi K2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
pretty_name = "Kimi K2 Instruct (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 620622774272

View File

@@ -1,15 +0,0 @@
short_id = "kimi-k2-thinking"
model_id = "mlx-community/Kimi-K2-Thinking"
name = "Kimi K2 Thinking (4-bit)"
description = "Kimi K2 Thinking is the latest, most capable version of open-source thinking model."
tags = []
[metadata]
model_id = "mlx-community/Kimi-K2-Thinking"
pretty_name = "Kimi K2 Thinking (4-bit)"
n_layers = 61
hidden_size = 7168
supports_tensor = true
[metadata.storage_size]
in_bytes = 706522120192

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-70b"
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
name = "Llama 3.1 70B (4-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
pretty_name = "Llama 3.1 70B (4-bit)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 40652242944

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-8b-8bit"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
name = "Llama 3.1 8B (8-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
pretty_name = "Llama 3.1 8B (8-bit)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 8954839040

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-8b-bf16"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
name = "Llama 3.1 8B (BF16)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
pretty_name = "Llama 3.1 8B (BF16)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 16882073600

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.1-8b"
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
name = "Llama 3.1 8B (4-bit)"
description = "Llama 3.1 is a large language model trained on the Llama 3.1 dataset."
tags = []
[metadata]
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
pretty_name = "Llama 3.1 8B (4-bit)"
n_layers = 32
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 4637851648

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.2-1b"
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
name = "Llama 3.2 1B (4-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
pretty_name = "Llama 3.2 1B (4-bit)"
n_layers = 16
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 729808896

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.2-3b-8bit"
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
name = "Llama 3.2 3B (8-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
pretty_name = "Llama 3.2 3B (8-bit)"
n_layers = 28
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 3501195264

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.2-3b"
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
name = "Llama 3.2 3B (4-bit)"
description = "Llama 3.2 is a large language model trained on the Llama 3.2 dataset."
tags = []
[metadata]
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
pretty_name = "Llama 3.2 3B (4-bit)"
n_layers = 28
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 1863319552

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.3-70b-8bit"
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
name = "Llama 3.3 70B (8-bit)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
pretty_name = "Llama 3.3 70B (8-bit)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 76799803392

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.3-70b-fp16"
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
name = "Llama 3.3 70B (FP16)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
pretty_name = "Llama 3.3 70B (FP16)"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 144383672320

View File

@@ -1,15 +0,0 @@
short_id = "llama-3.3-70b"
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
name = "Llama 3.3 70B (4-bit)"
description = "The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)"
tags = []
[metadata]
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
pretty_name = "Llama 3.3 70B"
n_layers = 80
hidden_size = 8192
supports_tensor = true
[metadata.storage_size]
in_bytes = 40652242944

View File

@@ -1,15 +0,0 @@
short_id = "minimax-m2.1-3bit"
model_id = "mlx-community/MiniMax-M2.1-3bit"
name = "MiniMax M2.1 3bit"
description = "MiniMax M2.1 3bit"
tags = []
[metadata]
model_id = "mlx-community/MiniMax-M2.1-3bit"
pretty_name = "MiniMax M2.1 3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 100086644736

View File

@@ -1,15 +0,0 @@
short_id = "minimax-m2.1-8bit"
model_id = "mlx-community/MiniMax-M2.1-8bit"
name = "MiniMax M2.1 8bit"
description = "MiniMax M2.1 8bit"
tags = []
[metadata]
model_id = "mlx-community/MiniMax-M2.1-8bit"
pretty_name = "MiniMax M2.1 8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
[metadata.storage_size]
in_bytes = 242986745856

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-0.6b-8bit"
model_id = "mlx-community/Qwen3-0.6B-8bit"
name = "Qwen3 0.6B (8-bit)"
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-0.6B-8bit"
pretty_name = "Qwen3 0.6B (8-bit)"
n_layers = 28
hidden_size = 1024
supports_tensor = false
[metadata.storage_size]
in_bytes = 698351616

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-0.6b"
model_id = "mlx-community/Qwen3-0.6B-4bit"
name = "Qwen3 0.6B (4-bit)"
description = "Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-0.6B-4bit"
pretty_name = "Qwen3 0.6B (4-bit)"
n_layers = 28
hidden_size = 1024
supports_tensor = false
[metadata.storage_size]
in_bytes = 342884352

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-235b-a22b-4bit"
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
name = "Qwen3 235B A22B (4-bit)"
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
pretty_name = "Qwen3 235B A22B (4-bit)"
n_layers = 94
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 141733920768

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-235b-a22b-8bit"
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
name = "Qwen3 235B A22B (8-bit)"
description = "Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
pretty_name = "Qwen3 235B A22B (8-bit)"
n_layers = 94
hidden_size = 4096
supports_tensor = true
[metadata.storage_size]
in_bytes = 268435456000

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-30b-8bit"
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
name = "Qwen3 30B A3B (8-bit)"
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
pretty_name = "Qwen3 30B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 33279705088

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-30b"
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
name = "Qwen3 30B A3B (4-bit)"
description = "Qwen3 30B is a large language model trained on the Qwen3 30B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
pretty_name = "Qwen3 30B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 17612931072

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-4bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
name = "Qwen3 80B A3B (4-bit)"
description = "Qwen3 80B"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
pretty_name = "Qwen3 80B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 46976204800

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-8bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
name = "Qwen3 80B A3B (8-bit)"
description = "Qwen3 80B"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
pretty_name = "Qwen3 80B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-thinking-4bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
name = "Qwen3 80B A3B Thinking (4-bit)"
description = "Qwen3 80B Reasoning model"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
pretty_name = "Qwen3 80B A3B (4-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-80b-a3B-thinking-8bit"
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
name = "Qwen3 80B A3B Thinking (8-bit)"
description = "Qwen3 80B Reasoning model"
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
pretty_name = "Qwen3 80B A3B (8-bit)"
n_layers = 48
hidden_size = 2048
supports_tensor = true
[metadata.storage_size]
in_bytes = 88814387200

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-coder-480b-a35b-4bit"
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
name = "Qwen3 Coder 480B A35B (4-bit)"
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
pretty_name = "Qwen3 Coder 480B A35B (4-bit)"
n_layers = 62
hidden_size = 6144
supports_tensor = true
[metadata.storage_size]
in_bytes = 289910292480

View File

@@ -1,15 +0,0 @@
short_id = "qwen3-coder-480b-a35b-8bit"
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
name = "Qwen3 Coder 480B A35B (8-bit)"
description = "Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset."
tags = []
[metadata]
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
pretty_name = "Qwen3 Coder 480B A35B (8-bit)"
n_layers = 62
hidden_size = 6144
supports_tensor = true
[metadata.storage_size]
in_bytes = 579820584960

View File

@@ -205,6 +205,14 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -218,6 +226,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -259,6 +268,20 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,24 +1,19 @@
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
import anyio
from anyio import create_task_group
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -35,6 +30,8 @@ from exo.shared.types.api import (
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ModelList,
@@ -55,7 +52,12 @@ from exo.shared.types.commands import (
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
@@ -67,8 +69,6 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -123,6 +123,7 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -153,6 +154,20 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -381,35 +396,8 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -417,16 +405,10 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -442,11 +424,23 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -458,7 +452,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -466,7 +460,13 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -495,7 +495,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -503,7 +503,13 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -544,8 +550,6 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -562,17 +566,16 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
return await self._collect_chat_completion(command.command_id)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -589,10 +592,7 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
def _calculate_total_available_memory(self) -> Memory:
@@ -654,14 +654,14 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -0,0 +1,107 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -29,6 +29,11 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -1,8 +1,5 @@
from anyio import Path, open_file
import tomlkit
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.models.model_meta import get_model_meta
from exo.utils.pydantic_ext import CamelCaseModel
@@ -14,27 +11,542 @@ class ModelCard(CamelCaseModel):
tags: list[str]
metadata: ModelMetadata
@staticmethod
async def load(path: Path) -> "ModelCard":
async with await open_file(path) as f:
data = await f.read()
py = tomlkit.loads(data)
return ModelCard.model_validate(py)
async def save(self, path: Path):
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def from_hf(model_id: str) -> "ModelCard":
short_name = model_id.split("/")[-1]
return ModelCard(
short_id=short_name,
model_id=ModelId(model_id),
name=short_name,
description=f"Custom model from {model_id}",
tags=[],
metadata=await get_model_meta(model_id),
)
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -6,6 +6,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
@@ -107,13 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_id,
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=False,
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -11,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call"
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"

View File

@@ -22,6 +22,7 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):

View File

@@ -202,9 +202,9 @@ def tensor_auto_parallel(
segments: int = 1
def _all_to_sharded(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - all to sharded")
return weight.ndim - 1, segments
# if path.endswith("bias"):
# logger.info(f"Sharding bias for {path} - all to sharded")
# return weight.ndim - 1, segments
return max(weight.ndim - 2, 0), segments
all_to_sharded_linear_in_place = partial(
@@ -216,10 +216,10 @@ def tensor_auto_parallel(
n = group.size()
def _sharded_to_all(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - sharded to all")
weight /= n
return None
# if path.endswith("bias"):
# logger.info(f"Sharding bias for {path} - sharded to all")
# weight /= n
# return None
return -1, segments
sharded_to_all_linear_in_place = partial(

View File

@@ -2,7 +2,9 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -20,6 +22,7 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -81,6 +84,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -187,7 +229,9 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -201,7 +245,9 @@ def load_mlx_items(
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
@@ -215,6 +261,7 @@ def load_mlx_items(
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
@@ -251,7 +298,14 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# Estimate timeout based on model size
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = 60 + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
@@ -365,6 +419,8 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -396,6 +452,11 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -17,15 +17,23 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
)
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main

View File

@@ -1,9 +1,21 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -11,6 +23,7 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -39,6 +52,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -48,6 +62,33 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -109,7 +150,20 @@ def main(
)
)
model, tokenizer = load_mlx_items(bound_instance, group)
def on_model_load_timeout() -> None:
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(
error_message="Model loading timed out"
),
)
)
time.sleep(0.5)
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -126,7 +180,7 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
model=cast(Model, model),
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -139,8 +193,6 @@ def main(
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -149,33 +201,47 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
assert model
assert tokenizer
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
@@ -207,6 +273,43 @@ def main(
break
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -0,0 +1,50 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -1,49 +1,64 @@
import http.client
from anyio import create_task_group, to_thread
import anyio
import httpx
from anyio import create_task_group
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
body = response.read().decode("utf-8").strip()
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
remote_node_id = NodeId(body)
break
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
return None
finally:
connection.close()
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
@@ -61,18 +76,33 @@ async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
client,
)
return reachable

13
uv.lock generated
View File

@@ -236,6 +236,7 @@ dependencies = [
{ name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -247,7 +248,6 @@ dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -269,6 +269,7 @@ requires-dist = [
{ name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" },
{ name = "fastapi", specifier = ">=0.116.1" },
{ name = "filelock", specifier = ">=3.18.0" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
@@ -280,7 +281,6 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
@@ -1380,15 +1380,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/a1/d62dfe7376beaaf1394917e0f8e93ee5f67fea8fcf4107501db35996586b/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38337540fbbddff8e999d59970f3c6f35a82de10053206a7562f1ea02d046fa5", size = 10033429, upload-time = "2026-01-05T10:45:14.333Z" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]
[[package]]
name = "tqdm"
version = "4.67.1"