mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-13 07:32:30 -05:00
Compare commits
17 Commits
ciaran/mes
...
e2e-tests
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8203596ab | ||
|
|
b88749a6c5 | ||
|
|
4a446b2779 | ||
|
|
a82feed8e3 | ||
|
|
da6e626f6f | ||
|
|
6950f94109 | ||
|
|
cf23916b8b | ||
|
|
d0c44273db | ||
|
|
80b29ba0d9 | ||
|
|
b6214c297f | ||
|
|
cc33213842 | ||
|
|
62e8110e97 | ||
|
|
98773437f3 | ||
|
|
a8acb3cafb | ||
|
|
a0721dbe57 | ||
|
|
50e2bcf93e | ||
|
|
7bed91c9c2 |
15
.dockerignore
Normal file
15
.dockerignore
Normal file
@@ -0,0 +1,15 @@
|
||||
.venv/
|
||||
.direnv/
|
||||
target/
|
||||
.git/
|
||||
.idea/
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
dashboard/node_modules/
|
||||
dashboard/.svelte-kit/
|
||||
dashboard/build/
|
||||
dist/
|
||||
*.pdb
|
||||
**/__pycache__
|
||||
**/.DS_Store
|
||||
.mlx_typings/
|
||||
29
.github/workflows/e2e.yml
vendored
Normal file
29
.github/workflows/e2e.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: e2e-tests
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
branches:
|
||||
- staging
|
||||
- main
|
||||
|
||||
jobs:
|
||||
e2e:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
|
||||
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
|
||||
/opt/microsoft /opt/az
|
||||
docker system prune -af
|
||||
df -h /
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Run E2E tests
|
||||
run: python3 e2e/run_all.py
|
||||
75
AGENTS.md
75
AGENTS.md
@@ -119,3 +119,78 @@ From .cursorrules:
|
||||
## Testing
|
||||
|
||||
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
|
||||
|
||||
## Dashboard UI Testing & Screenshots
|
||||
|
||||
### Building and Running the Dashboard
|
||||
```bash
|
||||
# Build the dashboard (must be done before running exo)
|
||||
cd dashboard && npm install && npm run build && cd ..
|
||||
|
||||
# Start exo (serves the dashboard at http://localhost:52415)
|
||||
uv run exo &
|
||||
sleep 8 # Wait for server to start
|
||||
```
|
||||
|
||||
### Taking Headless Screenshots with Playwright
|
||||
Use Playwright with headless Chromium for programmatic screenshots — no manual browser interaction needed.
|
||||
|
||||
**Setup (one-time):**
|
||||
```bash
|
||||
npx --yes playwright install chromium
|
||||
cd /tmp && npm init -y && npm install playwright
|
||||
```
|
||||
|
||||
**Taking screenshots:**
|
||||
```javascript
|
||||
// Run from /tmp where playwright is installed: cd /tmp && node -e "..."
|
||||
const { chromium } = require('playwright');
|
||||
(async () => {
|
||||
const browser = await chromium.launch({ headless: true });
|
||||
const page = await browser.newPage({ viewport: { width: 1280, height: 800 } });
|
||||
await page.goto('http://localhost:52415', { waitUntil: 'networkidle' });
|
||||
await page.waitForTimeout(2000);
|
||||
|
||||
// Inject test data into localStorage if needed (e.g., recent models)
|
||||
await page.evaluate(() => {
|
||||
localStorage.setItem('exo-recent-models', JSON.stringify([
|
||||
{ modelId: 'mlx-community/Qwen3-30B-A3B-4bit', launchedAt: Date.now() },
|
||||
]));
|
||||
});
|
||||
await page.reload({ waitUntil: 'networkidle' });
|
||||
await page.waitForTimeout(2000);
|
||||
|
||||
// Interact with UI elements
|
||||
await page.locator('text=SELECT MODEL').click();
|
||||
await page.waitForTimeout(1000);
|
||||
|
||||
// Take screenshot
|
||||
await page.screenshot({ path: '/tmp/screenshot.png', fullPage: false });
|
||||
await browser.close();
|
||||
})();
|
||||
```
|
||||
|
||||
### Uploading Images to GitHub PRs
|
||||
GitHub's API doesn't support direct image upload for PR comments. Workaround:
|
||||
|
||||
1. **Commit images to the branch** (temporarily):
|
||||
```bash
|
||||
cp /tmp/screenshot.png .
|
||||
git add screenshot.png
|
||||
git commit -m "temp: add screenshots for PR"
|
||||
git push origin <branch>
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
```
|
||||
|
||||
2. **Post PR comment** referencing the raw image URL (uses permanent commit SHA so images survive deletion):
|
||||
```bash
|
||||
gh pr comment <PR_NUMBER> --body "![Screenshot](https://raw.githubusercontent.com/exo-explore/exo/${COMMIT_SHA}/screenshot.png)"
|
||||
```
|
||||
|
||||
3. **Remove the images** from the branch:
|
||||
```bash
|
||||
git rm screenshot.png
|
||||
git commit -m "chore: remove temporary screenshot files"
|
||||
git push origin <branch>
|
||||
```
|
||||
The images still render in the PR comment because they reference the permanent commit SHA.
|
||||
|
||||
@@ -563,21 +563,45 @@ struct ContentView: View {
|
||||
}
|
||||
|
||||
private var rdmaStatusView: some View {
|
||||
let rdma = networkStatusService.status.rdmaStatus
|
||||
let rdmaStatuses = stateService.latestSnapshot?.nodeRdmaCtl ?? [:]
|
||||
let localNodeId = stateService.localNodeId
|
||||
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
|
||||
let localDevices = networkStatusService.status.localRdmaDevices
|
||||
let localPorts = networkStatusService.status.localRdmaActivePorts
|
||||
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
Text("RDMA: \(rdmaStatusText(rdma))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(rdmaStatusColor(rdma))
|
||||
if !rdma.devices.isEmpty {
|
||||
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
|
||||
if rdmaStatuses.isEmpty {
|
||||
Text("Cluster RDMA: No data")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
} else {
|
||||
Text("Cluster RDMA Status:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(Array(rdmaStatuses.keys.sorted()), id: \.self) { nodeId in
|
||||
if let status = rdmaStatuses[nodeId] {
|
||||
let nodeName =
|
||||
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
|
||||
let isLocal = nodeId == localNodeId
|
||||
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
|
||||
let statusText = status.enabled ? "Enabled" : "Disabled"
|
||||
let color: Color = status.enabled ? .green : .orange
|
||||
Text("\(prefix) \(statusText)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(color)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !localDevices.isEmpty {
|
||||
Text(" Local Devices: \(localDevices.joined(separator: ", "))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
if !rdma.activePorts.isEmpty {
|
||||
Text(" Active Ports:")
|
||||
if !localPorts.isEmpty {
|
||||
Text(" Local Active Ports:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(rdma.activePorts, id: \.device) { port in
|
||||
ForEach(localPorts, id: \.device) { port in
|
||||
Text(" \(port.device) port \(port.port): \(port.state)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.green)
|
||||
@@ -586,28 +610,6 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the label shown for the RDMA status.
///
/// When `rdmaCtlEnabled` is known it decides between "Enabled" and "Disabled";
/// when it is `nil` the label depends on whether any devices are listed.
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
    guard let ctlEnabled = rdma.rdmaCtlEnabled else {
        // No rdma_ctl verdict: fall back to the presence of devices.
        if rdma.devices.isEmpty {
            return "Not Available"
        }
        return "Available"
    }
    return ctlEnabled ? "Enabled" : "Disabled"
}
|
||||
|
||||
/// Returns the colour used for the RDMA status label.
///
/// A known `rdmaCtlEnabled` maps to green (true) or orange (false); when it is
/// `nil` the colour is green if any devices are listed and secondary otherwise.
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
    guard let ctlEnabled = rdma.rdmaCtlEnabled else {
        // No rdma_ctl verdict: fall back to the presence of devices.
        if rdma.devices.isEmpty {
            return .secondary
        }
        return .green
    }
    return ctlEnabled ? .green : .orange
}
|
||||
|
||||
private var sendBugReportButton: some View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Button {
|
||||
|
||||
@@ -15,6 +15,7 @@ struct ClusterState: Decodable {
|
||||
let nodeMemory: [String: MemoryInfo]
|
||||
let nodeSystem: [String: SystemInfo]
|
||||
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
|
||||
let nodeRdmaCtl: [String: NodeRdmaCtlStatus]
|
||||
|
||||
/// Computed property for backwards compatibility - merges granular state into NodeProfile
|
||||
var nodeProfiles: [String: NodeProfile] {
|
||||
@@ -65,6 +66,10 @@ struct ClusterState: Decodable {
|
||||
try container.decodeIfPresent(
|
||||
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
|
||||
) ?? [:]
|
||||
self.nodeRdmaCtl =
|
||||
try container.decodeIfPresent(
|
||||
[String: NodeRdmaCtlStatus].self, forKey: .nodeRdmaCtl
|
||||
) ?? [:]
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
@@ -78,6 +83,7 @@ struct ClusterState: Decodable {
|
||||
case nodeMemory
|
||||
case nodeSystem
|
||||
case nodeThunderboltBridge
|
||||
case nodeRdmaCtl
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,6 +165,10 @@ struct ThunderboltBridgeStatus: Decodable {
|
||||
let serviceName: String?
|
||||
}
|
||||
|
||||
/// Per-node RDMA control status as decoded from the cluster state payload.
struct NodeRdmaCtlStatus: Decodable {
    /// Whether the node reports RDMA as enabled.
    let enabled: Bool
}
|
||||
|
||||
struct MemoryInfo: Decodable {
|
||||
let ramTotal: MemoryValue?
|
||||
let ramAvailable: MemoryValue?
|
||||
|
||||
@@ -35,28 +35,18 @@ struct NetworkStatus: Equatable {
|
||||
let thunderboltBridgeState: ThunderboltState?
|
||||
let bridgeInactive: Bool?
|
||||
let interfaceStatuses: [InterfaceIpStatus]
|
||||
let rdmaStatus: RDMAStatus
|
||||
let localRdmaDevices: [String]
|
||||
let localRdmaActivePorts: [RDMAPort]
|
||||
|
||||
static let empty = NetworkStatus(
|
||||
thunderboltBridgeState: nil,
|
||||
bridgeInactive: nil,
|
||||
interfaceStatuses: [],
|
||||
rdmaStatus: .empty
|
||||
localRdmaDevices: [],
|
||||
localRdmaActivePorts: []
|
||||
)
|
||||
}
|
||||
|
||||
/// Snapshot of the local RDMA state: the optional rdma_ctl flag, the device
/// names, and the active ports.
struct RDMAStatus: Equatable {
    /// `nil` when the enabled/disabled state could not be determined.
    let rdmaCtlEnabled: Bool?
    let devices: [String]
    let activePorts: [RDMAPort]

    /// True when RDMA is explicitly enabled or at least one device is listed.
    var isAvailable: Bool {
        if rdmaCtlEnabled == true {
            return true
        }
        return !devices.isEmpty
    }

    /// Status with an unknown flag and no devices or ports.
    static let empty = RDMAStatus(
        rdmaCtlEnabled: nil,
        devices: [],
        activePorts: []
    )
}
|
||||
|
||||
struct RDMAPort: Equatable {
|
||||
let device: String
|
||||
let port: String
|
||||
@@ -80,31 +70,11 @@ private struct NetworkStatusFetcher {
|
||||
thunderboltBridgeState: readThunderboltBridgeState(),
|
||||
bridgeInactive: readBridgeInactive(),
|
||||
interfaceStatuses: readInterfaceStatuses(),
|
||||
rdmaStatus: readRDMAStatus()
|
||||
localRdmaDevices: readRDMADevices(),
|
||||
localRdmaActivePorts: readRDMAActivePorts()
|
||||
)
|
||||
}
|
||||
|
||||
/// Builds an `RDMAStatus` from the three individual readers.
///
/// Arguments are evaluated left to right, so the readers run in the order
/// ctl flag, devices, active ports.
private func readRDMAStatus() -> RDMAStatus {
    RDMAStatus(
        rdmaCtlEnabled: readRDMACtlEnabled(),
        devices: readRDMADevices(),
        activePorts: readRDMAActivePorts()
    )
}
|
||||
|
||||
/// Runs `rdma_ctl status` through `runCommand` and interprets its output.
///
/// - Returns: `true` if the lowercased output contains "enabled", `false` if
///   it contains "disabled", and `nil` when the command exits non-zero or
///   neither keyword is present.
private func readRDMACtlEnabled() -> Bool? {
    let status = runCommand(["rdma_ctl", "status"])
    if status.exitCode != 0 {
        return nil
    }

    let normalized = status.output
        .lowercased()
        .trimmingCharacters(in: .whitespacesAndNewlines)

    // Checked in order: "enabled" takes precedence over "disabled".
    let markers: [(keyword: String, enabled: Bool)] = [
        ("enabled", true),
        ("disabled", false),
    ]
    return markers.first { normalized.contains($0.keyword) }?.enabled
}
|
||||
|
||||
private func readRDMADevices() -> [String] {
|
||||
let result = runCommand(["ibv_devices"])
|
||||
guard result.exitCode == 0 else { return [] }
|
||||
|
||||
@@ -19,6 +19,11 @@ from urllib.parse import urlencode
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Backoff constants for cluster settling retry
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
@@ -388,6 +393,66 @@ class PromptSizer:
|
||||
return content, tok
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
    client: ExoClient, full_model_id: str, args: argparse.Namespace
) -> list[dict[str, Any]]:
    """Fetch placement previews for a model and keep those matching the CLI filters.

    Requests ``GET /instance/previews`` for ``full_model_id`` and returns the
    previews that:

    * carry no ``error``,
    * pass ``placement_filter`` / ``sharding_filter`` for ``args.instance_meta``
      and ``args.sharding``,
    * have a dict ``instance``,
    * are not excluded by the single-node, ``args.skip_pipeline_jaccl`` or
      ``args.skip_tensor_ring`` rules, and
    * use between ``args.min_nodes`` and ``args.max_nodes`` nodes (inclusive).

    Args:
        client: API client used to issue the previews request.
        full_model_id: Model identifier sent as the ``model_id`` query parameter.
        args: Parsed CLI arguments; reads ``instance_meta``, ``sharding``,
            ``skip_pipeline_jaccl``, ``skip_tensor_ring``, ``min_nodes`` and
            ``max_nodes``.

    Returns:
        The matching preview dicts, in the order the server returned them.
    """
    previews_resp = client.request_json(
        "GET", "/instance/previews", params={"model_id": full_model_id}
    )
    previews = previews_resp.get("previews") or []

    both_meta = args.instance_meta == "both"
    both_sharding = args.sharding == "both"

    selected: list[dict[str, Any]] = []
    for p in previews:
        if p.get("error") is not None:
            continue

        # Normalise both fields once. The filters already received str(...)
        # values; reusing them below avoids calling .lower() on a raw value
        # that may be None (or otherwise not a string) in the response.
        instance_meta = str(p.get("instance_meta", ""))
        sharding = str(p.get("sharding", ""))

        if not placement_filter(instance_meta, args.instance_meta):
            continue
        if not sharding_filter(sharding, args.sharding):
            continue

        instance = p.get("instance")
        if not isinstance(instance, dict):
            continue

        meta_lc = instance_meta.lower()
        sharding_lc = sharding.lower()
        is_jaccl = both_meta and "jaccl" in meta_lc
        is_ring = both_meta and "ring" in meta_lc
        is_tensor = both_sharding and "tensor" in sharding_lc
        is_pipeline = both_sharding and "pipeline" in sharding_lc

        n = nodes_used_in_instance(instance)

        # Skip tensor ring single node as it is pointless when pipeline ring
        if n == 1 and (is_tensor or is_jaccl):
            continue

        if args.skip_pipeline_jaccl and is_jaccl and is_pipeline:
            continue

        if args.skip_tensor_ring and is_ring and is_tensor:
            continue

        if args.min_nodes <= n <= args.max_nodes:
            selected.append(p)

    return selected
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
@@ -464,6 +529,12 @@ def main() -> int:
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -487,11 +558,6 @@ def main() -> int:
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = load_tokenizer_for_bench(full_model_id)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
@@ -503,54 +569,20 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
if not selected and args.settle_timeout > 0:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + args.settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
ttftMs,
|
||||
tps,
|
||||
totalTokens,
|
||||
thinkingEnabled as thinkingEnabledStore,
|
||||
setConversationThinking,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import ChatAttachments from "./ChatAttachments.svelte";
|
||||
import ImageParamsPanel from "./ImageParamsPanel.svelte";
|
||||
@@ -25,6 +27,7 @@
|
||||
autofocus?: boolean;
|
||||
showModelSelector?: boolean;
|
||||
modelTasks?: Record<string, string[]>;
|
||||
modelCapabilities?: Record<string, string[]>;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -34,6 +37,7 @@
|
||||
autofocus = true,
|
||||
showModelSelector = false,
|
||||
modelTasks = {},
|
||||
modelCapabilities = {},
|
||||
}: Props = $props();
|
||||
|
||||
let message = $state("");
|
||||
@@ -41,6 +45,7 @@
|
||||
let fileInputRef: HTMLInputElement | undefined = $state();
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
let isDragOver = $state(false);
|
||||
const thinkingEnabled = $derived(thinkingEnabledStore());
|
||||
let loading = $derived(isLoading());
|
||||
const currentModel = $derived(selectedChatModel());
|
||||
const instanceData = $derived(instances());
|
||||
@@ -95,6 +100,12 @@
|
||||
);
|
||||
});
|
||||
|
||||
const modelSupportsThinking = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
const caps = modelCapabilities[currentModel] || [];
|
||||
return caps.includes("thinking") && caps.includes("text");
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
currentModel !== null &&
|
||||
modelSupportsOnlyImageEditing(currentModel) &&
|
||||
@@ -282,7 +293,11 @@
|
||||
// Use image generation for text-to-image models
|
||||
generateImage(content);
|
||||
} else {
|
||||
sendMessage(content, files);
|
||||
sendMessage(
|
||||
content,
|
||||
files,
|
||||
modelSupportsThinking() ? thinkingEnabled : null,
|
||||
);
|
||||
}
|
||||
|
||||
// Refocus the textarea after sending
|
||||
@@ -520,6 +535,35 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<!-- Thinking toggle -->
|
||||
{#if modelSupportsThinking()}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => setConversationThinking(!thinkingEnabled)}
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded text-xs font-mono tracking-wide transition-all duration-200 flex-shrink-0 cursor-pointer border {thinkingEnabled
|
||||
? 'bg-exo-yellow/15 border-exo-yellow/40 text-exo-yellow'
|
||||
: 'bg-exo-medium-gray/30 border-exo-medium-gray/50 text-exo-light-gray/60 hover:text-exo-light-gray'}"
|
||||
title={thinkingEnabled
|
||||
? "Thinking enabled — click to disable"
|
||||
: "Thinking disabled — click to enable"}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="1.5"
|
||||
>
|
||||
<path
|
||||
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
<span>{thinkingEnabled ? "THINK" : "NO THINK"}</span>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Performance stats -->
|
||||
{#if currentTtft !== null || currentTps !== null}
|
||||
<div class="flex items-center gap-4 text-xs font-mono flex-shrink-0">
|
||||
|
||||
@@ -13,6 +13,12 @@
|
||||
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
|
||||
/>
|
||||
</svg>
|
||||
{:else if family === "recents"}
|
||||
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
|
||||
<path
|
||||
d="M13 3a9 9 0 0 0-9 9H1l3.89 3.89.07.14L9 12H6c0-3.87 3.13-7 7-7s7 3.13 7 7-3.13 7-7 7c-1.93 0-3.68-.79-4.94-2.06l-1.42 1.42A8.954 8.954 0 0 0 13 21a9 9 0 0 0 0-18zm-1 5v5l4.28 2.54.72-1.21-3.5-2.08V8H12z"
|
||||
/>
|
||||
</svg>
|
||||
{:else if family === "llama" || family === "meta"}
|
||||
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
|
||||
<path
|
||||
|
||||
@@ -5,15 +5,22 @@
|
||||
families: string[];
|
||||
selectedFamily: string | null;
|
||||
hasFavorites: boolean;
|
||||
hasRecents: boolean;
|
||||
onSelect: (family: string | null) => void;
|
||||
};
|
||||
|
||||
let { families, selectedFamily, hasFavorites, onSelect }: FamilySidebarProps =
|
||||
$props();
|
||||
let {
|
||||
families,
|
||||
selectedFamily,
|
||||
hasFavorites,
|
||||
hasRecents,
|
||||
onSelect,
|
||||
}: FamilySidebarProps = $props();
|
||||
|
||||
// Family display names
|
||||
const familyNames: Record<string, string> = {
|
||||
favorites: "Favorites",
|
||||
recents: "Recent",
|
||||
huggingface: "Hub",
|
||||
llama: "Meta",
|
||||
qwen: "Qwen",
|
||||
@@ -89,6 +96,31 @@
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Recent (only show if has recent models) -->
|
||||
{#if hasRecents}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => onSelect("recents")}
|
||||
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
|
||||
'recents'
|
||||
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
|
||||
: 'hover:bg-white/5 border-l-2 border-transparent'}"
|
||||
title="Recently launched models"
|
||||
>
|
||||
<FamilyLogos
|
||||
family="recents"
|
||||
class={selectedFamily === "recents"
|
||||
? "text-exo-yellow"
|
||||
: "text-white/50 group-hover:text-white/70"}
|
||||
/>
|
||||
<span
|
||||
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'recents'
|
||||
? 'text-exo-yellow'
|
||||
: 'text-white/40 group-hover:text-white/60'}">Recent</span
|
||||
>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- HuggingFace Hub -->
|
||||
<button
|
||||
type="button"
|
||||
|
||||
@@ -422,9 +422,16 @@
|
||||
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];
|
||||
|
||||
for (const edge of topology.edges) {
|
||||
const ip = edge.sendBackIp || "?";
|
||||
const iface =
|
||||
edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
|
||||
let ip: string;
|
||||
let iface: string | null;
|
||||
|
||||
if (edge.sourceRdmaIface || edge.sinkRdmaIface) {
|
||||
ip = "RDMA";
|
||||
iface = `${edge.sourceRdmaIface || "?"} \u2192 ${edge.sinkRdmaIface || "?"}`;
|
||||
} else {
|
||||
ip = edge.sendBackIp || "?";
|
||||
iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
|
||||
}
|
||||
|
||||
if (edge.source === nodeId1 && edge.target === nodeId2) {
|
||||
aToBCandidates.push({ ip, iface });
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
onToggleFavorite: (baseModelId: string) => void;
|
||||
onShowInfo: (group: ModelGroup) => void;
|
||||
downloadStatusMap?: Map<string, DownloadAvailability>;
|
||||
launchedAt?: number;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -54,6 +55,7 @@
|
||||
onToggleFavorite,
|
||||
onShowInfo,
|
||||
downloadStatusMap,
|
||||
launchedAt,
|
||||
}: ModelPickerGroupProps = $props();
|
||||
|
||||
// Group-level download status: show if any variant is downloaded
|
||||
@@ -75,6 +77,17 @@
|
||||
return `${mb}MB`;
|
||||
}
|
||||
|
||||
/**
 * Formats the time elapsed since `ts` (epoch milliseconds) as a short label:
 * "just now" under a minute, then whole minutes ("5m ago"), whole hours
 * ("3h ago"), and whole days ("2d ago").
 */
function timeAgo(ts: number): string {
  const elapsedSeconds = Math.floor((Date.now() - ts) / 1000);
  const elapsedMinutes = Math.floor(elapsedSeconds / 60);
  const elapsedHours = Math.floor(elapsedMinutes / 60);
  const elapsedDays = Math.floor(elapsedHours / 24);

  if (elapsedSeconds < 60) {
    return "just now";
  }
  if (elapsedMinutes < 60) {
    return `${elapsedMinutes}m ago`;
  }
  if (elapsedHours < 24) {
    return `${elapsedHours}h ago`;
  }
  return `${elapsedDays}d ago`;
}
|
||||
|
||||
// Check if any variant can fit
|
||||
const anyVariantFits = $derived(
|
||||
group.variants.some((v) => canModelFit(v.id)),
|
||||
@@ -300,6 +313,13 @@
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Time ago (for recent models) -->
|
||||
{#if launchedAt}
|
||||
<span class="text-xs font-mono text-white/20 flex-shrink-0">
|
||||
{timeAgo(launchedAt)}
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Download availability indicator -->
|
||||
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
|
||||
<span
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import ModelFilterPopover from "./ModelFilterPopover.svelte";
|
||||
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
|
||||
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
|
||||
import { getRecentEntries } from "$lib/stores/recents.svelte";
|
||||
|
||||
interface ModelInfo {
|
||||
id: string;
|
||||
@@ -53,6 +54,8 @@
|
||||
models: ModelInfo[];
|
||||
selectedModelId: string | null;
|
||||
favorites: Set<string>;
|
||||
recentModelIds?: string[];
|
||||
hasRecents?: boolean;
|
||||
existingModelIds: Set<string>;
|
||||
canModelFit: (modelId: string) => boolean;
|
||||
getModelFitStatus: (modelId: string) => ModelFitStatus;
|
||||
@@ -79,6 +82,8 @@
|
||||
models,
|
||||
selectedModelId,
|
||||
favorites,
|
||||
recentModelIds = [],
|
||||
hasRecents: hasRecentsTab = false,
|
||||
existingModelIds,
|
||||
canModelFit,
|
||||
getModelFitStatus,
|
||||
@@ -387,7 +392,11 @@
|
||||
// Filter by family
|
||||
if (selectedFamily === "favorites") {
|
||||
result = result.filter((g) => favorites.has(g.id));
|
||||
} else if (selectedFamily && selectedFamily !== "huggingface") {
|
||||
} else if (
|
||||
selectedFamily &&
|
||||
selectedFamily !== "huggingface" &&
|
||||
selectedFamily !== "recents"
|
||||
) {
|
||||
result = result.filter((g) => g.family === selectedFamily);
|
||||
}
|
||||
|
||||
@@ -461,6 +470,48 @@
|
||||
// Check if any favorites exist
|
||||
const hasFavorites = $derived(favorites.size > 0);
|
||||
|
||||
// Timestamp lookup for recent models
|
||||
const recentTimestamps = $derived(
|
||||
new Map(getRecentEntries().map((e) => [e.modelId, e.launchedAt])),
|
||||
);
|
||||
|
||||
// Recent models: single-variant ModelGroups in launch order
|
||||
const recentGroups = $derived.by((): ModelGroup[] => {
|
||||
if (!recentModelIds || recentModelIds.length === 0) return [];
|
||||
const result: ModelGroup[] = [];
|
||||
for (const id of recentModelIds) {
|
||||
const model = models.find((m) => m.id === id);
|
||||
if (model) {
|
||||
result.push({
|
||||
id: model.base_model || model.id,
|
||||
name: model.name || model.id,
|
||||
capabilities: model.capabilities || ["text"],
|
||||
family: model.family || "",
|
||||
variants: [model],
|
||||
smallestVariant: model,
|
||||
hasMultipleVariants: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
// Filtered recent groups (apply search query)
|
||||
const filteredRecentGroups = $derived.by((): ModelGroup[] => {
|
||||
if (!searchQuery.trim()) return recentGroups;
|
||||
const query = searchQuery.toLowerCase().trim();
|
||||
return recentGroups.filter(
|
||||
(g) =>
|
||||
g.name.toLowerCase().includes(query) ||
|
||||
g.variants.some(
|
||||
(v) =>
|
||||
v.id.toLowerCase().includes(query) ||
|
||||
(v.name || "").toLowerCase().includes(query) ||
|
||||
(v.quantization || "").toLowerCase().includes(query),
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
function toggleGroupExpanded(groupId: string) {
|
||||
const next = new Set(expandedGroups);
|
||||
if (next.has(groupId)) {
|
||||
@@ -618,6 +669,7 @@
|
||||
families={uniqueFamilies}
|
||||
{selectedFamily}
|
||||
{hasFavorites}
|
||||
hasRecents={hasRecentsTab}
|
||||
onSelect={(family) => (selectedFamily = family)}
|
||||
/>
|
||||
|
||||
@@ -725,6 +777,44 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{:else if selectedFamily === "recents"}
|
||||
<!-- Recent models view -->
|
||||
{#if filteredRecentGroups.length === 0}
|
||||
<div
|
||||
class="flex flex-col items-center justify-center h-full text-white/40 p-8"
|
||||
>
|
||||
<svg
|
||||
class="w-12 h-12 mb-3"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
d="M13 3a9 9 0 0 0-9 9H1l3.89 3.89.07.14L9 12H6c0-3.87 3.13-7 7-7s7 3.13 7 7-3.13 7-7 7c-1.93 0-3.68-.79-4.94-2.06l-1.42 1.42A8.954 8.954 0 0 0 13 21a9 9 0 0 0 0-18zm-1 5v5l4.28 2.54.72-1.21-3.5-2.08V8H12z"
|
||||
/>
|
||||
</svg>
|
||||
<p class="font-mono text-sm">
|
||||
{searchQuery
|
||||
? "No matching recent models"
|
||||
: "No recently launched models"}
|
||||
</p>
|
||||
</div>
|
||||
{:else}
|
||||
{#each filteredRecentGroups as group}
|
||||
<ModelPickerGroup
|
||||
{group}
|
||||
isExpanded={expandedGroups.has(group.id)}
|
||||
isFavorite={favorites.has(group.id)}
|
||||
{selectedModelId}
|
||||
{canModelFit}
|
||||
onToggleExpand={() => toggleGroupExpanded(group.id)}
|
||||
onSelectModel={handleSelect}
|
||||
{onToggleFavorite}
|
||||
onShowInfo={(g) => (infoGroup = g)}
|
||||
downloadStatusMap={getVariantDownloadMap(group)}
|
||||
launchedAt={recentTimestamps.get(group.variants[0]?.id ?? "")}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
{:else if filteredGroups.length === 0}
|
||||
<div
|
||||
class="flex flex-col items-center justify-center h-full text-white/40 p-8"
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
isTopologyMinimized,
|
||||
debugMode,
|
||||
nodeThunderboltBridge,
|
||||
nodeRdmaCtl,
|
||||
nodeIdentities,
|
||||
type NodeInfo,
|
||||
} from "$lib/stores/app.svelte";
|
||||
|
||||
@@ -31,6 +33,8 @@
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
const rdmaCtlData = $derived(nodeRdmaCtl());
|
||||
const identitiesData = $derived(nodeIdentities());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -333,14 +337,27 @@
|
||||
if (edge.source === a) entry.aToB = true;
|
||||
else entry.bToA = true;
|
||||
|
||||
const ip = edge.sendBackIp || "?";
|
||||
const ifaceInfo = getInterfaceLabel(edge.source, ip);
|
||||
let ip: string;
|
||||
let ifaceLabel: string;
|
||||
let missingIface: boolean;
|
||||
|
||||
if (edge.sourceRdmaIface || edge.sinkRdmaIface) {
|
||||
ip = "RDMA";
|
||||
ifaceLabel = `${edge.sourceRdmaIface || "?"} \u2192 ${edge.sinkRdmaIface || "?"}`;
|
||||
missingIface = false;
|
||||
} else {
|
||||
ip = edge.sendBackIp || "?";
|
||||
const ifaceInfo = getInterfaceLabel(edge.source, ip);
|
||||
ifaceLabel = ifaceInfo.label;
|
||||
missingIface = ifaceInfo.missing;
|
||||
}
|
||||
|
||||
entry.connections.push({
|
||||
from: edge.source,
|
||||
to: edge.target,
|
||||
ip,
|
||||
ifaceLabel: ifaceInfo.label,
|
||||
missingIface: ifaceInfo.missing,
|
||||
ifaceLabel,
|
||||
missingIface,
|
||||
});
|
||||
pairMap.set(key, entry);
|
||||
});
|
||||
@@ -1120,15 +1137,17 @@
|
||||
.text(` (${ramUsagePercent.toFixed(0)}%)`);
|
||||
}
|
||||
|
||||
// Debug mode: Show TB bridge status
|
||||
// Debug mode: Show TB bridge and RDMA status
|
||||
if (debugEnabled) {
|
||||
let debugLabelY =
|
||||
nodeInfo.y +
|
||||
iconBaseHeight / 2 +
|
||||
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
|
||||
const debugFontSize = showFullLabels ? 9 : 7;
|
||||
const debugLineHeight = showFullLabels ? 11 : 9;
|
||||
|
||||
const tbStatus = tbBridgeData[nodeInfo.id];
|
||||
if (tbStatus) {
|
||||
const tbY =
|
||||
nodeInfo.y +
|
||||
iconBaseHeight / 2 +
|
||||
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
|
||||
const tbFontSize = showFullLabels ? 9 : 7;
|
||||
const tbColor = tbStatus.enabled
|
||||
? "rgba(234,179,8,0.9)"
|
||||
: "rgba(100,100,100,0.7)";
|
||||
@@ -1136,12 +1155,46 @@
|
||||
nodeG
|
||||
.append("text")
|
||||
.attr("x", nodeInfo.x)
|
||||
.attr("y", tbY)
|
||||
.attr("y", debugLabelY)
|
||||
.attr("text-anchor", "middle")
|
||||
.attr("fill", tbColor)
|
||||
.attr("font-size", tbFontSize)
|
||||
.attr("font-size", debugFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(tbText);
|
||||
debugLabelY += debugLineHeight;
|
||||
}
|
||||
|
||||
const rdmaStatus = rdmaCtlData[nodeInfo.id];
|
||||
if (rdmaStatus !== undefined) {
|
||||
const rdmaColor = rdmaStatus.enabled
|
||||
? "rgba(74,222,128,0.9)"
|
||||
: "rgba(100,100,100,0.7)";
|
||||
const rdmaText = rdmaStatus.enabled ? "RDMA:ON" : "RDMA:OFF";
|
||||
nodeG
|
||||
.append("text")
|
||||
.attr("x", nodeInfo.x)
|
||||
.attr("y", debugLabelY)
|
||||
.attr("text-anchor", "middle")
|
||||
.attr("fill", rdmaColor)
|
||||
.attr("font-size", debugFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(rdmaText);
|
||||
debugLabelY += debugLineHeight;
|
||||
}
|
||||
|
||||
const identity = identitiesData[nodeInfo.id];
|
||||
if (identity?.osVersion) {
|
||||
nodeG
|
||||
.append("text")
|
||||
.attr("x", nodeInfo.x)
|
||||
.attr("y", debugLabelY)
|
||||
.attr("text-anchor", "middle")
|
||||
.attr("fill", "rgba(179,179,179,0.7)")
|
||||
.attr("font-size", debugFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(
|
||||
`macOS ${identity.osVersion}${identity.osBuildVersion ? ` (${identity.osBuildVersion})` : ""}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -49,6 +49,7 @@ export interface NodeInfo {
|
||||
};
|
||||
last_macmon_update: number;
|
||||
friendly_name?: string;
|
||||
os_version?: string;
|
||||
}
|
||||
|
||||
export interface TopologyEdge {
|
||||
@@ -56,6 +57,8 @@ export interface TopologyEdge {
|
||||
target: string;
|
||||
sendBackIp?: string;
|
||||
sendBackInterface?: string;
|
||||
sourceRdmaIface?: string;
|
||||
sinkRdmaIface?: string;
|
||||
}
|
||||
|
||||
export interface TopologyData {
|
||||
@@ -76,6 +79,8 @@ interface RawNodeIdentity {
|
||||
modelId?: string;
|
||||
chipId?: string;
|
||||
friendlyName?: string;
|
||||
osVersion?: string;
|
||||
osBuildVersion?: string;
|
||||
}
|
||||
|
||||
interface RawMemoryUsage {
|
||||
@@ -225,6 +230,19 @@ interface RawStateResponse {
|
||||
nodeMemory?: Record<string, RawMemoryUsage>;
|
||||
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
|
||||
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
|
||||
// Thunderbolt identifiers per node
|
||||
nodeThunderbolt?: Record<
|
||||
string,
|
||||
{
|
||||
interfaces: Array<{
|
||||
rdmaInterface: string;
|
||||
domainUuid: string;
|
||||
linkSpeed: string;
|
||||
}>;
|
||||
}
|
||||
>;
|
||||
// RDMA ctl status per node
|
||||
nodeRdmaCtl?: Record<string, { enabled: boolean }>;
|
||||
// Thunderbolt bridge status per node
|
||||
nodeThunderboltBridge?: Record<
|
||||
string,
|
||||
@@ -278,6 +296,7 @@ export interface Conversation {
|
||||
modelId: string | null;
|
||||
sharding: string | null;
|
||||
instanceType: string | null;
|
||||
enableThinking: boolean | null;
|
||||
}
|
||||
|
||||
const STORAGE_KEY = "exo-conversations";
|
||||
@@ -425,6 +444,7 @@ function transformTopology(
|
||||
},
|
||||
last_macmon_update: Date.now() / 1000,
|
||||
friendly_name: identity?.friendlyName,
|
||||
os_version: identity?.osVersion,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -437,6 +457,8 @@ function transformTopology(
|
||||
if (!Array.isArray(edgeList)) continue;
|
||||
for (const edge of edgeList) {
|
||||
let sendBackIp: string | undefined;
|
||||
let sourceRdmaIface: string | undefined;
|
||||
let sinkRdmaIface: string | undefined;
|
||||
if (edge && typeof edge === "object" && "sinkMultiaddr" in edge) {
|
||||
const multiaddr = edge.sinkMultiaddr;
|
||||
if (multiaddr) {
|
||||
@@ -444,10 +466,23 @@ function transformTopology(
|
||||
multiaddr.ip_address ||
|
||||
extractIpFromMultiaddr(multiaddr.address);
|
||||
}
|
||||
} else if (
|
||||
edge &&
|
||||
typeof edge === "object" &&
|
||||
"sourceRdmaIface" in edge
|
||||
) {
|
||||
sourceRdmaIface = edge.sourceRdmaIface;
|
||||
sinkRdmaIface = edge.sinkRdmaIface;
|
||||
}
|
||||
|
||||
if (nodes[source] && nodes[sink] && source !== sink) {
|
||||
edges.push({ source, target: sink, sendBackIp });
|
||||
edges.push({
|
||||
source,
|
||||
target: sink,
|
||||
sendBackIp,
|
||||
sourceRdmaIface,
|
||||
sinkRdmaIface,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -490,12 +525,32 @@ class AppStore {
|
||||
instances = $state<Record<string, unknown>>({});
|
||||
runners = $state<Record<string, unknown>>({});
|
||||
downloads = $state<Record<string, unknown[]>>({});
|
||||
nodeDisk = $state<
|
||||
Record<
|
||||
string,
|
||||
{ total: { inBytes: number }; available: { inBytes: number } }
|
||||
>
|
||||
>({});
|
||||
placementPreviews = $state<PlacementPreview[]>([]);
|
||||
selectedPreviewModelId = $state<string | null>(null);
|
||||
isLoadingPreviews = $state(false);
|
||||
previewNodeFilter = $state<Set<string>>(new Set());
|
||||
lastUpdate = $state<number | null>(null);
|
||||
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderbolt = $state<
|
||||
Record<
|
||||
string,
|
||||
{
|
||||
interfaces: Array<{
|
||||
rdmaInterface: string;
|
||||
domainUuid: string;
|
||||
linkSpeed: string;
|
||||
}>;
|
||||
}
|
||||
>
|
||||
>({});
|
||||
nodeRdmaCtl = $state<Record<string, { enabled: boolean }>>({});
|
||||
nodeThunderboltBridge = $state<
|
||||
Record<
|
||||
string,
|
||||
@@ -551,6 +606,7 @@ class AppStore {
|
||||
modelId: conversation.modelId ?? null,
|
||||
sharding: conversation.sharding ?? null,
|
||||
instanceType: conversation.instanceType ?? null,
|
||||
enableThinking: conversation.enableThinking ?? null,
|
||||
}));
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -740,6 +796,7 @@ class AppStore {
|
||||
modelId: derivedModelId,
|
||||
sharding: derivedSharding,
|
||||
instanceType: derivedInstanceType,
|
||||
enableThinking: null,
|
||||
};
|
||||
|
||||
this.conversations.unshift(conversation);
|
||||
@@ -765,6 +822,7 @@ class AppStore {
|
||||
this.hasStartedChat = true;
|
||||
this.isTopologyMinimized = true;
|
||||
this.isSidebarOpen = true; // Auto-open sidebar when chatting
|
||||
this.thinkingEnabled = conversation.enableThinking ?? true;
|
||||
this.refreshConversationModelFromInstances();
|
||||
|
||||
return true;
|
||||
@@ -1206,6 +1264,15 @@ class AppStore {
|
||||
if (data.downloads) {
|
||||
this.downloads = data.downloads;
|
||||
}
|
||||
if (data.nodeDisk) {
|
||||
this.nodeDisk = data.nodeDisk;
|
||||
}
|
||||
// Node identities (for OS version mismatch detection)
|
||||
this.nodeIdentities = data.nodeIdentities ?? {};
|
||||
// Thunderbolt identifiers per node
|
||||
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
|
||||
// RDMA ctl status per node
|
||||
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
|
||||
// Thunderbolt bridge cycles
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
|
||||
// Thunderbolt bridge status per node
|
||||
@@ -1869,6 +1936,11 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether thinking is enabled for the current conversation
|
||||
*/
|
||||
thinkingEnabled = $state(true);
|
||||
|
||||
/**
|
||||
* Selected model for chat (can be set by the UI)
|
||||
*/
|
||||
@@ -2047,6 +2119,7 @@ class AppStore {
|
||||
textContent?: string;
|
||||
preview?: string;
|
||||
}[],
|
||||
enableThinking?: boolean | null,
|
||||
): Promise<void> {
|
||||
if ((!content.trim() && (!files || files.length === 0)) || this.isLoading)
|
||||
return;
|
||||
@@ -2194,6 +2267,9 @@ class AppStore {
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
...(enableThinking != null && {
|
||||
enable_thinking: enableThinking,
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -2852,6 +2928,18 @@ class AppStore {
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Set the thinking toggle and persist it on the active conversation.
 *
 * The store-level `thinkingEnabled` flag is always updated. When a
 * conversation is currently active, the preference is also written onto it
 * and the conversation list is saved to storage.
 *
 * @param enabled - New value for the thinking preference.
 */
setConversationThinking(enabled: boolean) {
  this.thinkingEnabled = enabled;

  const active = this.getActiveConversation();
  // Nothing to persist when no conversation is active.
  if (!active) return;

  active.enableThinking = enabled;
  this.saveConversationsToStorage();
}
|
||||
|
||||
/**
|
||||
* Start a download on a specific node
|
||||
*/
|
||||
@@ -2958,12 +3046,14 @@ export const topologyData = () => appStore.topologyData;
|
||||
export const instances = () => appStore.instances;
|
||||
export const runners = () => appStore.runners;
|
||||
export const downloads = () => appStore.downloads;
|
||||
export const nodeDisk = () => appStore.nodeDisk;
|
||||
export const placementPreviews = () => appStore.placementPreviews;
|
||||
export const selectedPreviewModelId = () => appStore.selectedPreviewModelId;
|
||||
export const isLoadingPreviews = () => appStore.isLoadingPreviews;
|
||||
export const lastUpdate = () => appStore.lastUpdate;
|
||||
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
|
||||
export const selectedChatModel = () => appStore.selectedChatModel;
|
||||
export const thinkingEnabled = () => appStore.thinkingEnabled;
|
||||
export const debugMode = () => appStore.getDebugMode();
|
||||
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
|
||||
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
|
||||
@@ -2979,7 +3069,8 @@ export const sendMessage = (
|
||||
textContent?: string;
|
||||
preview?: string;
|
||||
}[],
|
||||
) => appStore.sendMessage(content, files);
|
||||
enableThinking?: boolean | null,
|
||||
) => appStore.sendMessage(content, files, enableThinking);
|
||||
export const generateImage = (prompt: string, modelId?: string) =>
|
||||
appStore.generateImage(prompt, modelId);
|
||||
export const editImage = (
|
||||
@@ -3022,6 +3113,8 @@ export const deleteAllConversations = () => appStore.deleteAllConversations();
|
||||
export const renameConversation = (id: string, name: string) =>
|
||||
appStore.renameConversation(id, name);
|
||||
export const getActiveConversation = () => appStore.getActiveConversation();
|
||||
export const setConversationThinking = (enabled: boolean) =>
|
||||
appStore.setConversationThinking(enabled);
|
||||
|
||||
// Sidebar actions
|
||||
export const isSidebarOpen = () => appStore.isSidebarOpen;
|
||||
@@ -3038,7 +3131,12 @@ export const setChatSidebarVisible = (visible: boolean) =>
|
||||
appStore.setChatSidebarVisible(visible);
|
||||
export const refreshState = () => appStore.fetchState();
|
||||
|
||||
// Thunderbolt bridge status
|
||||
// Node identities (for OS version mismatch detection)
|
||||
export const nodeIdentities = () => appStore.nodeIdentities;
|
||||
|
||||
// Thunderbolt & RDMA status
|
||||
export const nodeThunderbolt = () => appStore.nodeThunderbolt;
|
||||
export const nodeRdmaCtl = () => appStore.nodeRdmaCtl;
|
||||
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
|
||||
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
|
||||
|
||||
|
||||
75
dashboard/src/lib/stores/recents.svelte.ts
Normal file
75
dashboard/src/lib/stores/recents.svelte.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
/**
|
||||
* RecentsStore - Manages recently launched models with localStorage persistence
|
||||
*/
|
||||
|
||||
import { browser } from "$app/environment";
|
||||
|
||||
const RECENTS_KEY = "exo-recent-models";
|
||||
const MAX_RECENT_MODELS = 20;
|
||||
|
||||
interface RecentEntry {
|
||||
modelId: string;
|
||||
launchedAt: number;
|
||||
}
|
||||
|
||||
class RecentsStore {
  /** Launched models, most recent first, capped at MAX_RECENT_MODELS. */
  recents = $state<RecentEntry[]>([]);

  constructor() {
    // Only hydrate from localStorage when running in the browser.
    if (browser) {
      this.loadFromStorage();
    }
  }

  /**
   * Reduce an untrusted parsed value to a list of well-formed entries.
   *
   * Anything that is not an array yields an empty list; array items that do
   * not have a string `modelId` and a numeric `launchedAt` are dropped. The
   * result is capped at MAX_RECENT_MODELS.
   */
  private static sanitize(value: unknown): RecentEntry[] {
    if (!Array.isArray(value)) return [];
    const entries: RecentEntry[] = [];
    for (const item of value) {
      if (
        typeof item === "object" &&
        item !== null &&
        typeof (item as RecentEntry).modelId === "string" &&
        typeof (item as RecentEntry).launchedAt === "number"
      ) {
        const { modelId, launchedAt } = item as RecentEntry;
        entries.push({ modelId, launchedAt });
      }
    }
    return entries.slice(0, MAX_RECENT_MODELS);
  }

  /** Load persisted entries; parse failures are logged and leave `recents` unchanged. */
  private loadFromStorage() {
    try {
      const stored = localStorage.getItem(RECENTS_KEY);
      if (stored) {
        // The stored value may be stale or hand-edited, so validate its shape
        // instead of trusting a cast: a non-array here would make the later
        // `.filter` / `.map` calls throw.
        this.recents = RecentsStore.sanitize(JSON.parse(stored));
      }
    } catch (error) {
      console.error("Failed to load recent models:", error);
    }
  }

  /** Persist the current list; failures are logged, never thrown. */
  private saveToStorage() {
    // Mirror the constructor's guard: there is nothing to write to outside the browser.
    if (!browser) return;
    try {
      localStorage.setItem(RECENTS_KEY, JSON.stringify(this.recents));
    } catch (error) {
      console.error("Failed to save recent models:", error);
    }
  }

  /**
   * Record that `modelId` was just launched.
   *
   * Any previous entry for the same model is removed so the model moves to
   * the front; the list is then truncated to MAX_RECENT_MODELS and saved.
   */
  recordLaunch(modelId: string) {
    // Remove existing entry for this model (if any) to move it to top
    const filtered = this.recents.filter((r) => r.modelId !== modelId);
    // Prepend new entry
    const next = [{ modelId, launchedAt: Date.now() }, ...filtered];
    // Cap at max
    this.recents = next.slice(0, MAX_RECENT_MODELS);
    this.saveToStorage();
  }

  /** Model ids in most-recent-first order. */
  getRecentModelIds(): string[] {
    return this.recents.map((r) => r.modelId);
  }

  /** True when at least one launch has been recorded. */
  hasAny(): boolean {
    return this.recents.length > 0;
  }

  /** Forget every recorded launch and persist the empty list. */
  clearAll() {
    this.recents = [];
    this.saveToStorage();
  }
}
|
||||
|
||||
export const recentsStore = new RecentsStore();
|
||||
|
||||
export const hasRecents = () => recentsStore.hasAny();
|
||||
export const getRecentModelIds = () => recentsStore.getRecentModelIds();
|
||||
export const getRecentEntries = () => recentsStore.recents;
|
||||
export const recordRecentLaunch = (modelId: string) =>
|
||||
recentsStore.recordLaunch(modelId);
|
||||
export const clearRecents = () => recentsStore.clearAll();
|
||||
@@ -12,6 +12,11 @@
|
||||
toggleFavorite,
|
||||
getFavoritesSet,
|
||||
} from "$lib/stores/favorites.svelte";
|
||||
import {
|
||||
hasRecents,
|
||||
getRecentModelIds,
|
||||
recordRecentLaunch,
|
||||
} from "$lib/stores/recents.svelte";
|
||||
import {
|
||||
hasStartedChat,
|
||||
isTopologyMinimized,
|
||||
@@ -37,8 +42,11 @@
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
nodeThunderbolt,
|
||||
nodeRdmaCtl,
|
||||
thunderboltBridgeCycles,
|
||||
nodeThunderboltBridge,
|
||||
nodeIdentities,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview,
|
||||
} from "$lib/stores/app.svelte";
|
||||
@@ -62,8 +70,50 @@
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
const identitiesData = $derived(nodeIdentities());
|
||||
const tbIdentifiers = $derived(nodeThunderbolt());
|
||||
const rdmaCtlData = $derived(nodeRdmaCtl());
|
||||
const nodeFilter = $derived(previewNodeFilter());
|
||||
|
||||
// macOS version mismatch detection.
// Yields null when fewer than two macOS nodes are known or when they all
// agree; otherwise yields one row per macOS node for the warning tooltip.
const macosVersionMismatch = $derived.by(() => {
  if (!identitiesData) return null;

  // A node counts as macOS when its version is present, is not "Unknown",
  // and starts with a digit (e.g. "15.3").
  const macosNodes = Object.entries(identitiesData).filter(([, identity]) => {
    const version = identity.osVersion;
    return version && version !== "Unknown" && /^\d/.test(version);
  });
  if (macosNodes.length < 2) return null;

  // Compare on the build version, falling back to the OS version when a node
  // does not report a build.
  const distinctBuilds = new Set<string | undefined>();
  for (const [, identity] of macosNodes) {
    distinctBuilds.add(identity.osBuildVersion ?? identity.osVersion);
  }
  if (distinctBuilds.size <= 1) return null;

  return macosNodes.map(([nodeId, identity]) => ({
    nodeId,
    friendlyName: getNodeName(nodeId),
    version: identity.osVersion!,
    buildVersion: identity.osBuildVersion ?? "Unknown",
  }));
});
|
||||
|
||||
// True when at least two nodes report a Thunderbolt interface and at least
// one of those nodes does not have RDMA explicitly enabled.
const tb5WithoutRdma = $derived.by(() => {
  const rdmaByNode = rdmaCtlData;
  if (!rdmaByNode) return false;

  const thunderboltByNode = tbIdentifiers;
  if (!thunderboltByNode) return false;

  // Collect the nodes that expose any Thunderbolt interface.
  const thunderboltNodeIds: string[] = [];
  for (const [nodeId, info] of Object.entries(thunderboltByNode)) {
    if (info.interfaces.length > 0) {
      thunderboltNodeIds.push(nodeId);
    }
  }
  if (thunderboltNodeIds.length < 2) return false;

  // A missing RDMA entry counts the same as a disabled one.
  return !thunderboltNodeIds.every(
    (nodeId) => rdmaByNode[nodeId]?.enabled === true,
  );
});
|
||||
let tb5InfoDismissed = $state(false);
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -96,6 +146,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
// Warning icon SVG path (reused across warning snippets)
|
||||
const warningIconPath =
|
||||
"M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z";
|
||||
const infoIconPath =
|
||||
"M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z";
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
// Instance launch state
|
||||
@@ -134,6 +190,19 @@
|
||||
return tasks;
|
||||
});
|
||||
|
||||
// Capability lookup keyed by model id and, when present, by Hugging Face id.
// Models with no capabilities are omitted.
// NOTE(review): `$derived` is given an arrow function here, so the derived
// value looks like the function itself rather than the map; callers
// presumably invoke `modelCapabilities()` -- confirm before switching to
// `$derived.by`.
const modelCapabilities = $derived(() => {
  const byId: Record<string, string[]> = {};
  for (const model of models) {
    const capabilities = model.capabilities;
    if (!capabilities || !(capabilities.length > 0)) continue;

    byId[model.id] = capabilities;
    if (model.hugging_face_id) {
      byId[model.hugging_face_id] = capabilities;
    }
  }
  return byId;
});
|
||||
|
||||
// Helper to check if a model supports image generation
|
||||
function modelSupportsImageGeneration(modelId: string): boolean {
|
||||
const model = models.find(
|
||||
@@ -232,6 +301,10 @@
|
||||
// Favorites state (reactive)
|
||||
const favoritesSet = $derived(getFavoritesSet());
|
||||
|
||||
// Recent models state (reactive)
|
||||
const recentModelIds = $derived(getRecentModelIds());
|
||||
const showRecentsTab = $derived(hasRecents());
|
||||
|
||||
// Slider dragging state
|
||||
let isDraggingSlider = $state(false);
|
||||
let sliderTrackElement: HTMLDivElement | null = $state(null);
|
||||
@@ -661,6 +734,9 @@
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(modelId);
|
||||
|
||||
// Record the launch in recent models history
|
||||
recordRecentLaunch(modelId);
|
||||
|
||||
// Scroll to the bottom of instances container to show the new instance
|
||||
// Use multiple attempts to ensure DOM has updated with the new instance
|
||||
const scrollToBottom = () => {
|
||||
@@ -1688,6 +1764,249 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
{#snippet clusterWarnings()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="group relative" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected via
|
||||
Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if macosVersionMismatch}
|
||||
<div class="group relative" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
INCOMPATIBLE macOS VERSIONS
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
Nodes in this cluster are running different macOS versions. This
|
||||
may cause inference compatibility issues.
|
||||
</p>
|
||||
<div class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Node versions:</span>
|
||||
{#each macosVersionMismatch as node}
|
||||
<div class="ml-2">
|
||||
{node.friendlyName} — macOS {node.version} ({node.buildVersion})
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
<p class="text-xs text-white/60">
|
||||
<span class="text-yellow-300">Suggested action:</span> Update all nodes
|
||||
to the same macOS version for best compatibility.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if tb5WithoutRdma && !tb5InfoDismissed}
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-blue-400/50 bg-blue-400/10 backdrop-blur-sm"
|
||||
role="status"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-blue-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={infoIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-blue-200"> RDMA AVAILABLE </span>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => (tb5InfoDismissed = true)}
|
||||
class="ml-1 text-blue-300/60 hover:text-blue-200 transition-colors cursor-pointer"
|
||||
title="Dismiss"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
{#snippet clusterWarningsCompact()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
<div class="absolute top-2 left-2 flex flex-col gap-1">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
title="Thunderbolt Bridge cycle detected"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-yellow-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-yellow-200">TB CYCLE</span>
|
||||
</div>
|
||||
{/if}
|
||||
{#if macosVersionMismatch}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
title="Incompatible macOS versions detected"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-yellow-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-yellow-200"
|
||||
>macOS MISMATCH</span
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
{#if tb5WithoutRdma && !tb5InfoDismissed}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-blue-400/50 bg-blue-400/10 backdrop-blur-sm"
|
||||
title="Thunderbolt 5 detected — RDMA can be enabled for better performance"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-blue-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={infoIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-blue-200">RDMA AVAILABLE</span
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
<!-- Global event listeners for slider dragging -->
|
||||
<svelte:window
|
||||
onmousemove={handleSliderMouseMove}
|
||||
@@ -1755,17 +2074,40 @@
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
{@render clusterWarnings()}
|
||||
|
||||
<!-- TB5 RDMA Available Info -->
|
||||
{#if tb5WithoutRdma && !tb5InfoDismissed}
|
||||
<div
|
||||
class="absolute left-4 flex items-center gap-2 px-3 py-2 rounded border border-blue-400/50 bg-blue-400/10 backdrop-blur-sm"
|
||||
class:top-16={tbBridgeCycles.length > 0}
|
||||
class:top-4={tbBridgeCycles.length === 0}
|
||||
role="status"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-blue-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-blue-200">
|
||||
RDMA AVAILABLE
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => (tb5InfoDismissed = true)}
|
||||
class="ml-1 text-blue-300/60 hover:text-blue-200 transition-colors cursor-pointer"
|
||||
title="Dismiss"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
@@ -1774,60 +2116,10 @@
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -1874,17 +2166,21 @@
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
{@render clusterWarnings()}
|
||||
|
||||
<!-- TB5 RDMA Available Info -->
|
||||
{#if tb5WithoutRdma && !tb5InfoDismissed}
|
||||
<div
|
||||
class="absolute left-4 group"
|
||||
class:top-16={tbBridgeCycles.length > 0}
|
||||
class:top-4={tbBridgeCycles.length === 0}
|
||||
role="status"
|
||||
>
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-blue-400/50 bg-blue-400/10 backdrop-blur-sm"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
class="w-5 h-5 text-blue-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
@@ -1893,60 +2189,62 @@
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
<span class="text-sm font-mono text-blue-200">
|
||||
RDMA AVAILABLE
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
onclick={() => (tb5InfoDismissed = true)}
|
||||
class="ml-1 text-blue-300/60 hover:text-blue-200 transition-colors cursor-pointer"
|
||||
title="Dismiss"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-blue-400/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
Thunderbolt 5 hardware detected on multiple nodes. Enable
|
||||
RDMA for significantly faster inter-node communication.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1.5">
|
||||
<span class="text-blue-300">To enable:</span>
|
||||
</p>
|
||||
<ol
|
||||
class="text-xs text-white/60 list-decimal list-inside space-y-0.5 mb-1.5"
|
||||
>
|
||||
<li>Connect nodes with TB5 cables</li>
|
||||
<li>Boot to Recovery (hold power 10s → Options)</li>
|
||||
<li>
|
||||
Run
|
||||
<code class="text-blue-300 bg-blue-400/10 px-1 rounded"
|
||||
>rdma_ctl enable</code
|
||||
>
|
||||
</li>
|
||||
<li>Reboot</li>
|
||||
</ol>
|
||||
<p class="text-xs text-white/40">
|
||||
Requires macOS 26.2+, TB5 cables, and matching OS versions.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -1985,6 +2283,7 @@
|
||||
showHelperText={false}
|
||||
showModelSelector={true}
|
||||
modelTasks={modelTasks()}
|
||||
modelCapabilities={modelCapabilities()}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -2764,6 +3063,7 @@
|
||||
placeholder="Ask anything"
|
||||
showModelSelector={true}
|
||||
modelTasks={modelTasks()}
|
||||
modelCapabilities={modelCapabilities()}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -2804,30 +3104,7 @@
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
title="Thunderbolt Bridge cycle detected - click to view details"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-yellow-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-yellow-200"
|
||||
>TB CYCLE</span
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
{@render clusterWarningsCompact()}
|
||||
</div>
|
||||
</button>
|
||||
|
||||
@@ -3283,6 +3560,8 @@
|
||||
{models}
|
||||
{selectedModelId}
|
||||
favorites={favoritesSet}
|
||||
{recentModelIds}
|
||||
hasRecents={showRecentsTab}
|
||||
existingModelIds={new Set(models.map((m) => m.id))}
|
||||
canModelFit={(modelId) => {
|
||||
const model = models.find((m) => m.id === modelId);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import {
|
||||
topologyData,
|
||||
downloads,
|
||||
nodeDisk,
|
||||
type DownloadProgress,
|
||||
refreshState,
|
||||
lastUpdate as lastUpdateStore,
|
||||
@@ -37,10 +38,13 @@
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
models: ModelEntry[];
|
||||
diskAvailable?: number;
|
||||
diskTotal?: number;
|
||||
};
|
||||
|
||||
const data = $derived(topologyData());
|
||||
const downloadsData = $derived(downloads());
|
||||
const nodeDiskData = $derived(nodeDisk());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -327,10 +331,17 @@
|
||||
];
|
||||
}
|
||||
|
||||
// Get disk info for this node
|
||||
const diskInfo = nodeDiskData?.[nodeId];
|
||||
const diskAvailable = diskInfo?.available?.inBytes;
|
||||
const diskTotal = diskInfo?.total?.inBytes;
|
||||
|
||||
built.push({
|
||||
nodeId,
|
||||
nodeName: getNodeLabel(nodeId),
|
||||
models,
|
||||
diskAvailable,
|
||||
diskTotal,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -417,6 +428,14 @@
|
||||
<div class="text-xs text-exo-light-gray font-mono truncate">
|
||||
{node.nodeId}
|
||||
</div>
|
||||
<div class="text-xs text-exo-light-gray font-mono mt-1">
|
||||
{formatBytes(
|
||||
node.models
|
||||
.filter((m) => m.status === "completed")
|
||||
.reduce((sum, m) => sum + m.totalBytes, 0),
|
||||
)} models{#if node.diskAvailable != null}
|
||||
- {formatBytes(node.diskAvailable)} free{/if}
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right"
|
||||
@@ -429,13 +448,6 @@
|
||||
/ {node.models.length} models</span
|
||||
>
|
||||
</div>
|
||||
<div class="text-exo-light-gray normal-case tracking-normal">
|
||||
{formatBytes(
|
||||
node.models
|
||||
.filter((m) => m.status === "completed")
|
||||
.reduce((sum, m) => sum + m.totalBytes, 0),
|
||||
)} on disk
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
53
e2e/Dockerfile
Normal file
53
e2e/Dockerfile
Normal file
@@ -0,0 +1,53 @@
|
||||
# Stage 1: Build the dashboard
|
||||
FROM node:22-slim AS dashboard
|
||||
WORKDIR /app/dashboard
|
||||
COPY dashboard/package.json dashboard/package-lock.json ./
|
||||
RUN npm ci
|
||||
COPY dashboard/ .
|
||||
RUN npm run build
|
||||
|
||||
# Stage 2: Build and run exo
|
||||
FROM python:3.13-slim
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
curl \
|
||||
protobuf-compiler \
|
||||
iptables \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Rust nightly
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
|
||||
# Install uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files first for better layer caching
|
||||
COPY pyproject.toml Cargo.toml uv.lock README.md ./
|
||||
COPY rust/ ./rust/
|
||||
COPY bench/pyproject.toml ./bench/pyproject.toml
|
||||
|
||||
# Copy source and resources
|
||||
COPY src/ ./src/
|
||||
COPY resources/ ./resources/
|
||||
|
||||
# Copy built dashboard from stage 1
|
||||
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
|
||||
|
||||
# Install Python deps and build Rust bindings, then clean up build artifacts
|
||||
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
|
||||
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
|
||||
|
||||
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
|
||||
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
|
||||
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
|
||||
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
|
||||
chmod +x /usr/bin/g++
|
||||
|
||||
CMD [".venv/bin/exo", "-v"]
|
||||
182
e2e/conftest.py
Normal file
182
e2e/conftest.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Shared E2E test infrastructure for exo cluster tests."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
E2E_DIR = Path(__file__).parent.resolve()
|
||||
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
|
||||
|
||||
|
||||
class Cluster:
    """Async wrapper around a docker compose exo cluster.

    Each instance drives its own compose project (``e2e-<name>``) assembled
    from ``docker-compose.yml`` plus any override files, all resolved relative
    to ``E2E_DIR``. Use it as an async context manager so the project is torn
    down (``docker compose down``) when the test body exits, pass or fail.
    """

    # Base URL of the exo API as reached from the host running the tests.
    API_BASE = "http://localhost:52415"

    def __init__(self, name: str, overrides: list[str] | None = None):
        """Prepare the compose command line.

        Args:
            name: Test name; used to derive the compose project name.
            overrides: Extra compose files, as paths relative to ``E2E_DIR``,
                layered on top of the base ``docker-compose.yml`` in order.
        """
        self.name = name
        self.project = f"e2e-{name}"
        compose_files = [str(E2E_DIR / "docker-compose.yml")]
        for path in overrides or []:
            compose_files.append(str(E2E_DIR / path))
        self._compose_base = [
            "docker",
            "compose",
            "-p",
            self.project,
            *[arg for f in compose_files for arg in ("-f", f)],
        ]

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        await self.stop()

    async def _compose(self, *args: str) -> tuple[int, str]:
        """Run ``docker compose <args>`` and return (returncode, combined output)."""
        proc = await asyncio.create_subprocess_exec(
            *self._compose_base,
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        stdout, _ = await proc.communicate()
        # Container output is arbitrary bytes; never let a bad byte abort a test.
        return proc.returncode, stdout.decode(errors="replace")

    @staticmethod
    def _fetch(req, timeout: float) -> bytes:
        """Blocking HTTP fetch; closes the response before returning its body."""
        with urlopen(req, timeout=timeout) as resp:
            return resp.read()

    async def _run(self, *args: str, check: bool = True) -> str:
        """Run a compose subcommand and return its combined stdout/stderr.

        Raises:
            RuntimeError: if ``check`` is true and the command exits non-zero
                (the captured output is echoed to stderr first).
        """
        returncode, output = await self._compose(*args)
        if check and returncode != 0:
            print(output, file=sys.stderr)
            raise RuntimeError(
                f"docker compose {' '.join(args)} failed (rc={returncode})"
            )
        return output

    async def build(self):
        print(" Building images...")
        await self._run("build", "--quiet")

    async def start(self):
        print(" Starting cluster...")
        await self._run("up", "-d")

    async def stop(self):
        print(" Cleaning up...")
        # Best effort: teardown must not mask the real test failure.
        await self._run("down", "--timeout", "5", check=False)

    async def logs(self) -> str:
        return await self._run("logs", check=False)

    async def exec(
        self, service: str, *cmd: str, check: bool = True
    ) -> tuple[int, str]:
        """Run a command inside a running container. Returns (returncode, output).

        Raises:
            RuntimeError: if ``check`` is true and the command exits non-zero
                (the captured output is echoed to stderr first, as in ``_run``).
        """
        returncode, output = await self._compose("exec", "-T", service, *cmd)
        if check and returncode != 0:
            print(output, file=sys.stderr)
            raise RuntimeError(
                f"exec {' '.join(cmd)} in {service} failed (rc={returncode})"
            )
        return returncode, output

    async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
        """Poll check_fn every 2s until it returns True or timeout expires.

        Raises:
            TimeoutError: when the deadline passes; the cluster logs are dumped
                to stderr first to aid debugging.
        """
        print(f" Waiting for {description}...")
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            if await check_fn():
                print(f" {description}")
                return
            await asyncio.sleep(2)
        output = await self.logs()
        print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
        raise TimeoutError(f"Timed out waiting for {description}")

    async def assert_healthy(self):
        """Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""

        async def both_nodes_started():
            log = await self.logs()
            return log.count("Starting node") >= 2

        async def nodes_discovered():
            log = await self.logs()
            return log.count("ConnectionMessageType.Connected") >= 2

        async def master_elected():
            log = await self.logs()
            return "demoting self" in log

        def probe_api() -> bool:
            try:
                with urlopen(f"{self.API_BASE}/v1/models", timeout=3) as resp:
                    return resp.status == 200
            except (URLError, OSError):
                return False

        async def api_responding():
            # urlopen blocks; keep it off the event loop thread.
            return await asyncio.to_thread(probe_api)

        await self.wait_for("Both nodes started", both_nodes_started)
        await self.wait_for("Nodes discovered each other", nodes_discovered)
        await self.wait_for("Master election resolved", master_elected)
        await self.wait_for("API responding", api_responding)

    async def _api(
        self, method: str, path: str, body: dict | None = None, timeout: int = 30
    ) -> dict:
        """Make an API request to the cluster. Returns parsed JSON.

        ``body`` is JSON-encoded whenever it is not ``None`` (an empty dict is
        sent as ``{}`` rather than being silently dropped).
        """
        data = json.dumps(body).encode() if body is not None else None
        req = Request(
            f"{self.API_BASE}{path}",
            data=data,
            headers={"Content-Type": "application/json"},
            method=method,
        )
        resp_bytes = await asyncio.to_thread(self._fetch, req, timeout)
        return json.loads(resp_bytes)

    async def place_model(self, model: str, timeout: int = 600):
        """Place a model instance on the cluster (triggers download) and wait until it's ready."""
        await self._api("POST", "/place_instance", {"model_id": model})

        async def model_ready():
            try:
                resp = await self._api("GET", "/v1/models")
                return any(m.get("id") == model for m in resp.get("data", []))
            except Exception:
                # Any transient API failure just means "not ready yet".
                return False

        await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)

    async def chat(
        self, model: str, messages: list[dict], timeout: int = 600, **kwargs
    ) -> dict:
        """Send a chat completion request. Retries until model is downloaded and inference completes.

        Extra keyword arguments are merged into the request payload. Each
        attempt has its own 300s HTTP timeout; failed attempts are retried
        every 5s until ``timeout`` seconds have elapsed.

        Raises:
            TimeoutError: if no attempt succeeded before the deadline; the
                message carries the last error seen.
        """
        body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        last_error = None

        while loop.time() < deadline:
            try:
                req = Request(
                    f"{self.API_BASE}/v1/chat/completions",
                    data=body,
                    headers={"Content-Type": "application/json"},
                )
                resp_bytes = await asyncio.to_thread(self._fetch, req, 300)
                return json.loads(resp_bytes)
            except Exception as e:
                last_error = e
                await asyncio.sleep(5)

        raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")
|
||||
18
e2e/docker-compose.yml
Normal file
18
e2e/docker-compose.yml
Normal file
@@ -0,0 +1,18 @@
|
||||
services:
|
||||
exo-node-1:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile
|
||||
environment:
|
||||
- EXO_LIBP2P_NAMESPACE=docker-e2e
|
||||
command: [".venv/bin/exo", "-v"]
|
||||
ports:
|
||||
- "52415:52415"
|
||||
|
||||
exo-node-2:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: e2e/Dockerfile
|
||||
environment:
|
||||
- EXO_LIBP2P_NAMESPACE=docker-e2e
|
||||
command: [".venv/bin/exo", "-v"]
|
||||
75
e2e/run_all.py
Normal file
75
e2e/run_all.py
Normal file
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Discovers and runs all E2E tests in e2e/test_*.py.
|
||||
|
||||
Tests whose module docstring contains a line starting with 'slow' are skipped
|
||||
unless --slow is passed or E2E_SLOW=1 is set.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
E2E_DIR = Path(__file__).parent.resolve()
|
||||
|
||||
|
||||
def is_slow(test_file: Path) -> bool:
    """Return True when the test file's module docstring marks it as slow.

    Leading blank lines and ``#`` comment lines (e.g. a shebang) are skipped.
    The first remaining line must open a triple-quoted docstring; the file is
    slow if any later line of that docstring starts with ``slow`` (after
    stripping whitespace). Scanning stops at the docstring's closing delimiter.

    A docstring that opens and closes on the same line has no body lines and
    therefore never marks the file as slow.
    """
    with open(test_file) as f:
        for line in f:
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue

            quote = next(
                (q for q in ('"""', "'''") if stripped.startswith(q)), None
            )
            if quote is None:
                # First real line is code, so there is no module docstring.
                return False
            if quote in stripped[len(quote):]:
                # One-line docstring: do not scan the code that follows it.
                return False

            for doc_line in f:
                if doc_line.strip().startswith("slow"):
                    return True
                if quote in doc_line:
                    return False
            return False
    return False
|
||||
|
||||
|
||||
def main():
    """Run every ``test_*.py`` script in ``E2E_DIR`` and print a summary.

    Each test runs in its own interpreter via ``subprocess.run``; exit code 0
    counts as a pass. Files flagged by ``is_slow`` are skipped unless
    ``--slow`` is on the command line or ``E2E_SLOW=1`` is set. Exits with
    status 1 when no test files exist or when any test failed.
    """
    run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
    test_files = sorted(E2E_DIR.glob("test_*.py"))
    if not test_files:
        print("No test files found")
        sys.exit(1)

    passed = 0
    skipped = 0
    failures = []

    for test_file in test_files:
        name = test_file.stem
        if is_slow(test_file) and not run_slow:
            print(f"=== {name} === SKIPPED (slow, use --slow to run)", flush=True)
            skipped += 1
            continue

        # Flush before spawning: the child writes straight to the inherited
        # stdout, so a buffered header would otherwise land after its output
        # whenever stdout is a pipe (e.g. in CI logs).
        print(f"=== {name} ===", flush=True)
        result = subprocess.run([sys.executable, str(test_file)])
        if result.returncode == 0:
            passed += 1
        else:
            failures.append(name)
        print(flush=True)

    failed = len(failures)
    total = passed + failed + skipped
    print("================================")
    print(
        f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
    )

    if failed:
        print(f"Failed: {' '.join(failures)}")
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
8
e2e/snapshots/inference.json
Normal file
8
e2e/snapshots/inference.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"model": "mlx-community/Qwen3-0.6B-4bit",
|
||||
"seed": 42,
|
||||
"temperature": 0,
|
||||
"prompt": "What is 2+2? Reply with just the number.",
|
||||
"max_tokens": 32,
|
||||
"content": "<think>\nOkay, so I need to figure out what 2+2 is. Let me think. Well, if you add 2 and 2 together"
|
||||
}
|
||||
22
e2e/test_cluster_formation.py
Normal file
22
e2e/test_cluster_formation.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Test: Basic cluster formation.
|
||||
|
||||
Verifies two nodes discover each other, elect a master, and the API responds.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
|
||||
async def main():
    """Bring up a two-node cluster and check that it reaches a healthy state."""
    cluster_ctx = Cluster("cluster_formation")
    async with cluster_ctx as cluster:
        # Build, start, then run the shared health assertions in order.
        for step in (cluster.build, cluster.start, cluster.assert_healthy):
            await step()
        print("PASSED: cluster_formation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
82
e2e/test_inference_snapshot.py
Normal file
82
e2e/test_inference_snapshot.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Test: Deterministic inference output (snapshot test).
|
||||
slow
|
||||
|
||||
Sends a chat completion request with a fixed seed and temperature=0,
|
||||
then verifies the output matches a known-good snapshot. This ensures
|
||||
inference produces consistent results across runs.
|
||||
|
||||
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
|
||||
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
MODEL = "mlx-community/Qwen3-0.6B-4bit"
|
||||
SEED = 42
|
||||
PROMPT = "What is 2+2? Reply with just the number."
|
||||
MAX_TOKENS = 32
|
||||
SNAPSHOT_FILE = Path(__file__).parent / "snapshots" / "inference.json"
|
||||
|
||||
|
||||
async def main():
    """Run one deterministic chat completion and compare it with the stored snapshot.

    When the snapshot file does not exist yet it is created from the current
    response instead of being compared.
    """
    async with Cluster("inference_snapshot") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        # Launch the model instance (triggers download + placement)
        print(f" Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f" Sending chat completion (seed={SEED}, temperature=0)...")
        user_turn = {"role": "user", "content": PROMPT}
        resp = await cluster.chat(
            model=MODEL,
            messages=[user_turn],
            seed=SEED,
            temperature=0,
            max_tokens=MAX_TOKENS,
        )

        first_choice = resp["choices"][0]
        content = first_choice["message"]["content"]
        print(f" Response: {content!r}")

        # Create the snapshot on first run, otherwise compare against it
        if not SNAPSHOT_FILE.exists():
            SNAPSHOT_FILE.parent.mkdir(parents=True, exist_ok=True)
            record = {
                "model": MODEL,
                "seed": SEED,
                "temperature": 0,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
                "content": content,
            }
            SNAPSHOT_FILE.write_text(json.dumps(record, indent=2) + "\n")
            print(f" Snapshot created: {SNAPSHOT_FILE}")
        else:
            snapshot = json.loads(SNAPSHOT_FILE.read_text())
            expected = snapshot["content"]
            assert content == expected, (
                f"Snapshot mismatch!\n"
                f" Expected: {expected!r}\n"
                f" Got: {content!r}\n"
                f" Delete {SNAPSHOT_FILE} to regenerate."
            )
            print(" Output matches snapshot")

        print("PASSED: inference_snapshot")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
47
e2e/test_no_internet.py
Normal file
47
e2e/test_no_internet.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Test: Cluster works without internet access.
|
||||
|
||||
Verifies exo functions correctly when containers can talk to each other
|
||||
but cannot reach the internet. Uses iptables to block all outbound traffic
|
||||
except private subnets and multicast (for mDNS discovery).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
|
||||
from conftest import Cluster
|
||||
|
||||
|
||||
async def main():
    """Start a cluster with outbound internet blocked and verify exo notices.

    Uses the ``no_internet`` compose override, confirms from inside each
    container that an external HTTPS request fails, then checks the cluster
    logs for exo's own connectivity report.
    """
    override_files = ["tests/no_internet/docker-compose.override.yml"]
    async with Cluster("no_internet", overrides=override_files) as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        # Verify internet is actually blocked from inside the containers
        probe = ("curl", "-sf", "--max-time", "3", "https://huggingface.co")
        for node in ("exo-node-1", "exo-node-2"):
            rc, _ = await cluster.exec(node, *probe, check=False)
            assert rc != 0, f"{node} should not be able to reach the internet"
            print(f" {node}: internet correctly blocked")

        # Verify exo detected no internet connectivity
        log = await cluster.logs()
        assert "Internet connectivity: False" in log, "exo should detect no internet"
        print(" exo correctly detected no internet connectivity")

        print("PASSED: no_internet")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
32
e2e/tests/no_internet/docker-compose.override.yml
Normal file
32
e2e/tests/no_internet/docker-compose.override.yml
Normal file
@@ -0,0 +1,32 @@
|
||||
# Block all outbound internet traffic using iptables while preserving:
|
||||
# - Multicast (224.0.0.0/4) for mDNS peer discovery
|
||||
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
|
||||
# - Loopback (127/8)
|
||||
# Requires NET_ADMIN capability for iptables.
|
||||
services:
|
||||
exo-node-1:
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
entrypoint: ["/bin/sh", "-c"]
|
||||
command:
|
||||
- |
|
||||
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
|
||||
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
|
||||
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
|
||||
iptables -A OUTPUT -j REJECT
|
||||
exec .venv/bin/exo -v
|
||||
exo-node-2:
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
entrypoint: ["/bin/sh", "-c"]
|
||||
command:
|
||||
- |
|
||||
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
|
||||
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
|
||||
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
|
||||
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
|
||||
iptables -A OUTPUT -j REJECT
|
||||
exec .venv/bin/exo -v
|
||||
@@ -56,8 +56,49 @@ class DownloadCoordinator:
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
# Per-model throttle for download progress events
|
||||
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
    """Create the event channel and register the shared progress callback."""
    sender, receiver = channel[Event]()
    self.event_sender = sender
    self.event_receiver = receiver
    # One callback is registered for all downloads; it keys its state by model id.
    self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
async def _download_progress_callback(
    self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
    """Translate a shard-downloader progress report into a download event.

    A "complete" report always records and emits a ``DownloadCompleted`` and
    clears the model's active-download and throttle entries. An "in_progress"
    report records and emits a ``DownloadOngoing`` only if more than one
    second has passed since the last one emitted for the same model. Any
    other status is ignored.
    """
    throttle_interval_secs = 1.0
    model_id = callback_shard.model_card.model_id

    if progress.status == "complete":
        finished = DownloadCompleted(
            shard_metadata=callback_shard,
            node_id=self.node_id,
            total_bytes=progress.total_bytes,
        )
        self.download_status[model_id] = finished
        await self.event_sender.send(NodeDownloadProgress(download_progress=finished))
        # Drop per-model bookkeeping now that the download is done.
        if model_id in self.active_downloads:
            del self.active_downloads[model_id]
        self._last_progress_time.pop(model_id, None)
        return

    if progress.status != "in_progress":
        return

    elapsed = current_time() - self._last_progress_time.get(model_id, 0.0)
    if not elapsed > throttle_interval_secs:
        return

    progress_data = map_repo_download_progress_to_download_progress_data(progress)
    running = DownloadOngoing(
        node_id=self.node_id,
        shard_metadata=callback_shard,
        download_progress=progress_data,
    )
    self.download_status[model_id] = running
    await self.event_sender.send(NodeDownloadProgress(download_progress=running))
    self._last_progress_time[model_id] = current_time()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
@@ -119,12 +160,12 @@ class DownloadCoordinator:
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
# Check if already downloading or complete
|
||||
# Check if already downloading, complete, or recently failed
|
||||
if model_id in self.download_status:
|
||||
status = self.download_status[model_id]
|
||||
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
|
||||
if isinstance(status, (DownloadOngoing, DownloadCompleted, DownloadFailed)):
|
||||
logger.debug(
|
||||
f"Download for {model_id} already in progress or complete, skipping"
|
||||
f"Download for {model_id} already in progress, complete, or failed, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -169,46 +210,6 @@ class DownloadCoordinator:
|
||||
self.download_status[model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
async def download_progress_callback(
|
||||
callback_shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal last_progress_time
|
||||
|
||||
if progress.status == "complete":
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = completed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=completed)
|
||||
)
|
||||
# Clean up active download tracking
|
||||
if callback_shard.model_card.model_id in self.active_downloads:
|
||||
del self.active_downloads[callback_shard.model_card.model_id]
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
ongoing = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=callback_shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = ongoing
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=ongoing)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
|
||||
async def download_wrapper() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(shard)
|
||||
@@ -283,6 +284,12 @@ class DownloadCoordinator:
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
model_id = progress.shard.model_card.model_id
|
||||
|
||||
# Active downloads emit progress via the callback — don't overwrite
|
||||
if model_id in self.active_downloads:
|
||||
continue
|
||||
|
||||
if progress.status == "complete":
|
||||
status: DownloadProgress = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
|
||||
@@ -79,6 +79,7 @@ def chat_request_to_text_generation(
|
||||
seed=request.seed,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
enable_thinking=request.enable_thinking,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
|
||||
@@ -31,6 +31,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
NodeRdmaCtlStatus,
|
||||
NodeThunderboltInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
@@ -47,7 +48,9 @@ from exo.utils.info_gatherer.info_gatherer import (
|
||||
MemoryUsage,
|
||||
MiscData,
|
||||
NodeConfig,
|
||||
NodeDiskUsage,
|
||||
NodeNetworkInterfaces,
|
||||
RdmaCtlStatus,
|
||||
StaticNodeInformation,
|
||||
ThunderboltBridgeInfo,
|
||||
)
|
||||
@@ -223,6 +226,9 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
node_memory = {
|
||||
key: value for key, value in state.node_memory.items() if key != event.node_id
|
||||
}
|
||||
node_disk = {
|
||||
key: value for key, value in state.node_disk.items() if key != event.node_id
|
||||
}
|
||||
node_system = {
|
||||
key: value for key, value in state.node_system.items() if key != event.node_id
|
||||
}
|
||||
@@ -239,6 +245,9 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
for key, value in state.node_thunderbolt_bridge.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_rdma_ctl = {
|
||||
key: value for key, value in state.node_rdma_ctl.items() if key != event.node_id
|
||||
}
|
||||
# Only recompute cycles if the leaving node had TB bridge enabled
|
||||
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
leaving_node_had_tb_enabled = (
|
||||
@@ -256,10 +265,12 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
"last_seen": last_seen,
|
||||
"node_identities": node_identities,
|
||||
"node_memory": node_memory,
|
||||
"node_disk": node_disk,
|
||||
"node_system": node_system,
|
||||
"node_network": node_network,
|
||||
"node_thunderbolt": node_thunderbolt,
|
||||
"node_thunderbolt_bridge": node_thunderbolt_bridge,
|
||||
"node_rdma_ctl": node_rdma_ctl,
|
||||
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
|
||||
}
|
||||
)
|
||||
@@ -288,6 +299,8 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
update["node_memory"] = {**state.node_memory, event.node_id: info.memory}
|
||||
case MemoryUsage():
|
||||
update["node_memory"] = {**state.node_memory, event.node_id: info}
|
||||
case NodeDiskUsage():
|
||||
update["node_disk"] = {**state.node_disk, event.node_id: info.disk_usage}
|
||||
case NodeConfig():
|
||||
pass
|
||||
case MiscData():
|
||||
@@ -302,7 +315,12 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
case StaticNodeInformation():
|
||||
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
|
||||
new_identity = current_identity.model_copy(
|
||||
update={"model_id": info.model, "chip_id": info.chip}
|
||||
update={
|
||||
"model_id": info.model,
|
||||
"chip_id": info.chip,
|
||||
"os_version": info.os_version,
|
||||
"os_build_version": info.os_build_version,
|
||||
}
|
||||
)
|
||||
update["node_identities"] = {
|
||||
**state.node_identities,
|
||||
@@ -354,6 +372,11 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
new_tb_bridge, state.node_network
|
||||
)
|
||||
)
|
||||
case RdmaCtlStatus():
|
||||
update["node_rdma_ctl"] = {
|
||||
**state.node_rdma_ctl,
|
||||
event.node_id: NodeRdmaCtlStatus(enabled=info.enabled),
|
||||
}
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
@@ -199,6 +199,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
enable_thinking: bool | None = None
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import shutil
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Literal, Self
|
||||
|
||||
import psutil
|
||||
@@ -38,6 +40,22 @@ class MemoryUsage(CamelCaseModel):
|
||||
)
|
||||
|
||||
|
||||
class DiskUsage(CamelCaseModel):
    """Disk space usage for the models directory.

    Both fields are ``Memory`` values built from the byte counts reported by
    ``shutil.disk_usage``.
    """

    # Total size of the partition that was measured.
    total: Memory
    # Free space on that partition (the ``free`` field of ``shutil.disk_usage``).
    available: Memory

    @classmethod
    def from_path(cls, path: Path) -> Self:
        """Get disk usage stats for the partition containing path.

        Args:
            path: Directory (or file) whose partition should be measured.

        Returns:
            A ``DiskUsage`` with the partition's total and free space.

        Raises:
            OSError: Propagated from ``shutil.disk_usage`` if no usable
                location can be measured.
        """
        # The directory may not have been created yet. In that case measure
        # the nearest ancestor that does exist, which lives on the partition
        # the directory would be created on, instead of failing outright.
        target = next(
            (candidate for candidate in (path, *path.parents) if candidate.exists()),
            path,
        )
        total, _used, free = shutil.disk_usage(target)
        return cls(
            total=Memory.from_bytes(total),
            available=Memory.from_bytes(free),
        )
|
||||
|
||||
|
||||
class SystemPerformanceProfile(CamelCaseModel):
|
||||
# TODO: flops_fp16: float
|
||||
|
||||
@@ -63,6 +81,8 @@ class NodeIdentity(CamelCaseModel):
|
||||
model_id: str = "Unknown"
|
||||
chip_id: str = "Unknown"
|
||||
friendly_name: str = "Unknown"
|
||||
os_version: str = "Unknown"
|
||||
os_build_version: str = "Unknown"
|
||||
|
||||
|
||||
class NodeNetworkInfo(CamelCaseModel):
|
||||
@@ -77,6 +97,12 @@ class NodeThunderboltInfo(CamelCaseModel):
|
||||
interfaces: Sequence[ThunderboltIdentifier] = []
|
||||
|
||||
|
||||
class NodeRdmaCtlStatus(CamelCaseModel):
|
||||
"""Whether RDMA is enabled on this node (via rdma_ctl)."""
|
||||
|
||||
enabled: bool
|
||||
|
||||
|
||||
class ThunderboltBridgeStatus(CamelCaseModel):
|
||||
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
|
||||
|
||||
|
||||
@@ -8,9 +8,11 @@ from pydantic.alias_generators import to_camel
|
||||
from exo.shared.topology import Topology, TopologySnapshot
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import (
|
||||
DiskUsage,
|
||||
MemoryUsage,
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
NodeRdmaCtlStatus,
|
||||
NodeThunderboltInfo,
|
||||
SystemPerformanceProfile,
|
||||
ThunderboltBridgeStatus,
|
||||
@@ -49,10 +51,12 @@ class State(CamelCaseModel):
|
||||
# Granular node state mappings (update independently at different frequencies)
|
||||
node_identities: Mapping[NodeId, NodeIdentity] = {}
|
||||
node_memory: Mapping[NodeId, MemoryUsage] = {}
|
||||
node_disk: Mapping[NodeId, DiskUsage] = {}
|
||||
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
|
||||
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
|
||||
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
|
||||
node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] = {}
|
||||
|
||||
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
|
||||
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
|
||||
|
||||
@@ -40,5 +40,6 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
|
||||
stop: str | list[str] | None = None
|
||||
seed: int | None = None
|
||||
chat_template_messages: list[dict[str, Any]] | None = None
|
||||
enable_thinking: bool | None = None
|
||||
logprobs: bool = False
|
||||
top_logprobs: int | None = None
|
||||
|
||||
@@ -12,6 +12,7 @@ class ThunderboltConnection(CamelCaseModel):
|
||||
class ThunderboltIdentifier(CamelCaseModel):
|
||||
rdma_interface: str
|
||||
domain_uuid: str
|
||||
link_speed: str = ""
|
||||
|
||||
|
||||
## Intentionally minimal, only collecting data we care about - there's a lot more
|
||||
@@ -19,6 +20,7 @@ class ThunderboltIdentifier(CamelCaseModel):
|
||||
|
||||
class _ReceptacleTag(BaseModel, extra="ignore"):
|
||||
receptacle_id_key: str | None = None
|
||||
current_speed_key: str | None = None
|
||||
|
||||
|
||||
class _ConnectivityItem(BaseModel, extra="ignore"):
|
||||
@@ -42,7 +44,9 @@ class ThunderboltConnectivityData(BaseModel, extra="ignore"):
|
||||
# if tag not in ifaces: return None
|
||||
iface = f"rdma_{ifaces[tag]}"
|
||||
return ThunderboltIdentifier(
|
||||
rdma_interface=iface, domain_uuid=self.domain_uuid_key
|
||||
rdma_interface=iface,
|
||||
domain_uuid=self.domain_uuid_key,
|
||||
link_speed=self.receptacle_1_tag.current_speed_key or "",
|
||||
)
|
||||
|
||||
def conn(self) -> ThunderboltConnection | None:
|
||||
|
||||
@@ -8,16 +8,17 @@ from subprocess import CalledProcessError
|
||||
from typing import Self, cast
|
||||
|
||||
import anyio
|
||||
from anyio import create_task_group, open_process
|
||||
from anyio import create_task_group, fail_after, open_process, to_thread
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio.streams.buffered import BufferedByteReceiveStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_CONFIG_FILE
|
||||
from exo.shared.constants import EXO_CONFIG_FILE, EXO_MODELS_DIR
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
DiskUsage,
|
||||
MemoryUsage,
|
||||
NetworkInterfaceInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
@@ -31,7 +32,13 @@ from exo.utils.channels import Sender
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .macmon import MacmonMetrics
|
||||
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
|
||||
from .system_info import (
|
||||
get_friendly_name,
|
||||
get_model_and_chip,
|
||||
get_network_interfaces,
|
||||
get_os_build_version,
|
||||
get_os_version,
|
||||
)
|
||||
|
||||
IS_DARWIN = sys.platform == "darwin"
|
||||
|
||||
@@ -177,11 +184,18 @@ class StaticNodeInformation(TaggedModel):
|
||||
|
||||
model: str
|
||||
chip: str
|
||||
os_version: str
|
||||
os_build_version: str
|
||||
|
||||
@classmethod
|
||||
async def gather(cls) -> Self:
|
||||
model, chip = await get_model_and_chip()
|
||||
return cls(model=model, chip=chip)
|
||||
return cls(
|
||||
model=model,
|
||||
chip=chip,
|
||||
os_version=get_os_version(),
|
||||
os_build_version=await get_os_build_version(),
|
||||
)
|
||||
|
||||
|
||||
class NodeNetworkInterfaces(TaggedModel):
|
||||
@@ -196,6 +210,28 @@ class MacThunderboltConnections(TaggedModel):
|
||||
conns: Sequence[ThunderboltConnection]
|
||||
|
||||
|
||||
class RdmaCtlStatus(TaggedModel):
    """Outcome of probing ``rdma_ctl status`` on this node."""

    enabled: bool

    @classmethod
    async def gather(cls) -> Self | None:
        """Run ``rdma_ctl status`` and interpret its output.

        Returns:
            An instance with ``enabled`` set according to the command output,
            or ``None`` when the probe cannot be performed or understood: not
            running on Darwin, ``rdma_ctl`` not on ``PATH``, the command timing
            out (5 seconds) or failing to start, a non-zero exit code, or
            output containing neither ``"enabled"`` nor ``"disabled"``.
        """
        probe_available = IS_DARWIN and shutil.which("rdma_ctl") is not None
        if not probe_available:
            return None

        try:
            with anyio.fail_after(5):
                result = await anyio.run_process(["rdma_ctl", "status"], check=False)
        except (TimeoutError, OSError):
            return None

        if result.returncode != 0:
            return None

        text = result.stdout.decode("utf-8").lower().strip()
        # Checked in this order; the first keyword found decides the result.
        for keyword, flag in (("enabled", True), ("disabled", False)):
            if keyword in text:
                return cls(enabled=flag)
        return None
|
||||
|
||||
|
||||
class ThunderboltBridgeInfo(TaggedModel):
|
||||
status: ThunderboltBridgeStatus
|
||||
|
||||
@@ -284,6 +320,20 @@ class MiscData(TaggedModel):
|
||||
return cls(friendly_name=await get_friendly_name())
|
||||
|
||||
|
||||
class NodeDiskUsage(TaggedModel):
    """Disk space information for the models directory."""

    disk_usage: DiskUsage

    @classmethod
    async def gather(cls) -> Self:
        """Measure disk usage for ``EXO_MODELS_DIR`` in a worker thread.

        The measurement runs through ``to_thread.run_sync`` rather than being
        called directly from the coroutine.
        """
        usage = await to_thread.run_sync(DiskUsage.from_path, EXO_MODELS_DIR)
        return cls(disk_usage=usage)
|
||||
|
||||
|
||||
async def _gather_iface_map() -> dict[str, str] | None:
|
||||
proc = await anyio.run_process(
|
||||
["networksetup", "-listallhardwareports"], check=False
|
||||
@@ -310,10 +360,12 @@ GatheredInfo = (
|
||||
| NodeNetworkInterfaces
|
||||
| MacThunderboltIdentifiers
|
||||
| MacThunderboltConnections
|
||||
| RdmaCtlStatus
|
||||
| ThunderboltBridgeInfo
|
||||
| NodeConfig
|
||||
| MiscData
|
||||
| StaticNodeInformation
|
||||
| NodeDiskUsage
|
||||
)
|
||||
|
||||
|
||||
@@ -326,6 +378,9 @@ class InfoGatherer:
|
||||
memory_poll_rate: float | None = None if IS_DARWIN else 1
|
||||
macmon_interval: float | None = 1 if IS_DARWIN else None
|
||||
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
|
||||
static_info_poll_interval: float | None = 60
|
||||
rdma_ctl_poll_interval: float | None = 10 if IS_DARWIN else None
|
||||
disk_poll_interval: float | None = 30
|
||||
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||
|
||||
async def run(self):
|
||||
@@ -335,25 +390,38 @@ class InfoGatherer:
|
||||
tg.start_soon(self._monitor_macmon, macmon_path)
|
||||
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
|
||||
tg.start_soon(self._monitor_thunderbolt_bridge_status)
|
||||
tg.start_soon(self._monitor_rdma_ctl_status)
|
||||
tg.start_soon(self._watch_system_info)
|
||||
tg.start_soon(self._monitor_memory_usage)
|
||||
tg.start_soon(self._monitor_misc)
|
||||
tg.start_soon(self._monitor_static_info)
|
||||
tg.start_soon(self._monitor_disk_usage)
|
||||
|
||||
nc = await NodeConfig.gather()
|
||||
if nc is not None:
|
||||
await self.info_sender.send(nc)
|
||||
sni = await StaticNodeInformation.gather()
|
||||
await self.info_sender.send(sni)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _monitor_static_info(self):
    """Periodically gather ``StaticNodeInformation`` and publish it.

    Does nothing when ``static_info_poll_interval`` is ``None``. Each cycle is
    bounded by a 30 second timeout; any failure is logged as a warning and the
    loop continues after sleeping for the poll interval.
    """
    if self.static_info_poll_interval is None:
        return

    while True:
        try:
            with fail_after(30):
                static_info = await StaticNodeInformation.gather()
                await self.info_sender.send(static_info)
        except Exception as e:
            logger.warning(f"Error gathering static node info: {e}")
        await anyio.sleep(self.static_info_poll_interval)
|
||||
|
||||
async def _monitor_misc(self):
|
||||
if self.misc_poll_interval is None:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
await self.info_sender.send(await MiscData.gather())
|
||||
with fail_after(10):
|
||||
await self.info_sender.send(await MiscData.gather())
|
||||
except Exception as e:
|
||||
logger.warning(f"Error gathering misc data: {e}")
|
||||
await anyio.sleep(self.misc_poll_interval)
|
||||
@@ -361,20 +429,26 @@ class InfoGatherer:
|
||||
async def _monitor_system_profiler_thunderbolt_data(self):
|
||||
if self.system_profiler_interval is None:
|
||||
return
|
||||
iface_map = await _gather_iface_map()
|
||||
if iface_map is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = await ThunderboltConnectivity.gather()
|
||||
assert data is not None
|
||||
with fail_after(30):
|
||||
iface_map = await _gather_iface_map()
|
||||
if iface_map is None:
|
||||
raise ValueError("Failed to gather interface map")
|
||||
|
||||
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
|
||||
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
|
||||
data = await ThunderboltConnectivity.gather()
|
||||
assert data is not None
|
||||
|
||||
conns = [it for i in data if (it := i.conn()) is not None]
|
||||
await self.info_sender.send(MacThunderboltConnections(conns=conns))
|
||||
idents = [
|
||||
it for i in data if (it := i.ident(iface_map)) is not None
|
||||
]
|
||||
await self.info_sender.send(
|
||||
MacThunderboltIdentifiers(idents=idents)
|
||||
)
|
||||
|
||||
conns = [it for i in data if (it := i.conn()) is not None]
|
||||
await self.info_sender.send(MacThunderboltConnections(conns=conns))
|
||||
except Exception as e:
|
||||
logger.warning(f"Error gathering Thunderbolt data: {e}")
|
||||
await anyio.sleep(self.system_profiler_interval)
|
||||
@@ -402,8 +476,9 @@ class InfoGatherer:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
nics = await get_network_interfaces()
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
with fail_after(10):
|
||||
nics = await get_network_interfaces()
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
except Exception as e:
|
||||
logger.warning(f"Error gathering network interfaces: {e}")
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
@@ -413,37 +488,70 @@ class InfoGatherer:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None:
|
||||
await self.info_sender.send(curr)
|
||||
with fail_after(30):
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None:
|
||||
await self.info_sender.send(curr)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error gathering Thunderbolt Bridge status: {e}")
|
||||
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
async def _monitor_rdma_ctl_status(self):
    """Periodically probe the RDMA control status and publish it.

    Does nothing when ``rdma_ctl_poll_interval`` is ``None``. A probe that
    yields ``None`` is skipped without sending anything; any exception is
    logged as a warning and the loop continues after the poll interval.
    """
    if self.rdma_ctl_poll_interval is None:
        return

    while True:
        try:
            if (status := await RdmaCtlStatus.gather()) is not None:
                await self.info_sender.send(status)
        except Exception as e:
            logger.warning(f"Error gathering RDMA ctl status: {e}")
        await anyio.sleep(self.rdma_ctl_poll_interval)
|
||||
|
||||
async def _monitor_disk_usage(self):
    """Periodically gather ``NodeDiskUsage`` and publish it.

    Does nothing when ``disk_poll_interval`` is ``None``. Each cycle is bounded
    by a 5 second timeout; any failure is logged as a warning and the loop
    continues after sleeping for the poll interval.
    """
    if self.disk_poll_interval is None:
        return

    while True:
        try:
            with fail_after(5):
                disk_info = await NodeDiskUsage.gather()
                await self.info_sender.send(disk_info)
        except Exception as e:
            logger.warning(f"Error gathering disk usage: {e}")
        await anyio.sleep(self.disk_poll_interval)
|
||||
|
||||
async def _monitor_macmon(self, macmon_path: str):
|
||||
if self.macmon_interval is None:
|
||||
return
|
||||
# macmon pipe --interval [interval in ms]
|
||||
try:
|
||||
async with await open_process(
|
||||
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
|
||||
) as p:
|
||||
if not p.stdout:
|
||||
logger.critical("MacMon closed stdout")
|
||||
return
|
||||
async for text in TextReceiveStream(
|
||||
BufferedByteReceiveStream(p.stdout)
|
||||
):
|
||||
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
|
||||
except CalledProcessError as e:
|
||||
stderr_msg = "no stderr"
|
||||
stderr_output = cast(bytes | str | None, e.stderr)
|
||||
if stderr_output is not None:
|
||||
stderr_msg = (
|
||||
stderr_output.decode()
|
||||
if isinstance(stderr_output, bytes)
|
||||
else str(stderr_output)
|
||||
while True:
|
||||
try:
|
||||
async with await open_process(
|
||||
[
|
||||
macmon_path,
|
||||
"pipe",
|
||||
"--interval",
|
||||
str(self.macmon_interval * 1000),
|
||||
]
|
||||
) as p:
|
||||
if not p.stdout:
|
||||
logger.critical("MacMon closed stdout")
|
||||
return
|
||||
async for text in TextReceiveStream(
|
||||
BufferedByteReceiveStream(p.stdout)
|
||||
):
|
||||
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
|
||||
except CalledProcessError as e:
|
||||
stderr_msg = "no stderr"
|
||||
stderr_output = cast(bytes | str | None, e.stderr)
|
||||
if stderr_output is not None:
|
||||
stderr_msg = (
|
||||
stderr_output.decode()
|
||||
if isinstance(stderr_output, bytes)
|
||||
else str(stderr_output)
|
||||
)
|
||||
logger.warning(
|
||||
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
|
||||
)
|
||||
logger.warning(
|
||||
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in macmon monitor: {e}")
|
||||
await anyio.sleep(self.macmon_interval)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
@@ -8,6 +9,34 @@ from anyio import run_process
|
||||
from exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo
|
||||
|
||||
|
||||
def get_os_version() -> str:
    """Return the OS version string for this node.

    On macOS this is the version reported by ``platform.mac_ver()``
    (e.g. ``"15.3"``). On every other platform it is the platform name from
    ``platform.system()`` (e.g. ``"Linux"``). An empty result from either
    source is reported as ``"Unknown"``.
    """
    if sys.platform != "darwin":
        return platform.system() or "Unknown"
    return platform.mac_ver()[0] or "Unknown"
|
||||
|
||||
|
||||
async def get_os_build_version() -> str:
    """Return the macOS build version string (e.g. ``"24D5055b"``).

    Returns:
        The trimmed output of ``sw_vers -buildVersion`` on macOS. Returns
        ``"Unknown"`` on non-macOS platforms, when ``sw_vers`` exits with an
        error or cannot be launched, or when it prints nothing.
    """
    if sys.platform != "darwin":
        return "Unknown"

    try:
        process = await run_process(["sw_vers", "-buildVersion"])
    except (CalledProcessError, OSError):
        # OSError covers the case where ``sw_vers`` cannot be started at all
        # (e.g. not found on PATH); treat it like any other failed lookup
        # instead of letting it propagate to the caller.
        return "Unknown"

    return process.stdout.decode("utf-8", errors="replace").strip() or "Unknown"
|
||||
|
||||
|
||||
async def get_friendly_name() -> str:
|
||||
"""
|
||||
Asynchronously gets the 'Computer Name' (friendly name) of a Mac.
|
||||
|
||||
@@ -462,11 +462,19 @@ def apply_chat_template(
|
||||
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
|
||||
formatted_messages = formatted_messages[:-1]
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if task_params.enable_thinking is not None:
|
||||
# Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
|
||||
# Jinja ignores unknown variables, so passing both is safe.
|
||||
extra_kwargs["enable_thinking"] = task_params.enable_thinking
|
||||
extra_kwargs["thinking"] = task_params.enable_thinking
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=task_params.tools,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
if partial_assistant_content:
|
||||
|
||||
@@ -184,6 +184,14 @@ class Worker:
|
||||
)
|
||||
if task is None:
|
||||
continue
|
||||
|
||||
# Gate DownloadModel on backoff BEFORE emitting TaskCreated
|
||||
# to prevent flooding the event log with useless events
|
||||
if isinstance(task, DownloadModel):
|
||||
model_id = task.shard_metadata.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
continue
|
||||
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
|
||||
@@ -199,9 +207,6 @@ class Worker:
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
model_id = shard.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
continue
|
||||
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
|
||||
await self.download_command_sender.send(
|
||||
|
||||
Reference in New Issue
Block a user