30 Commits

Author  SHA1  Message  Date
Evan  7d2e828aba  stop pinging loopback addresses  2025-12-27 12:17:35 +00:00
Evan  b5319d6b03  switch from sequence to map of connections  2025-12-27 12:00:22 +00:00
Evan  b988e08d69  pydantic types are now coherent  2025-12-27 11:20:49 +00:00
Evan  9bf5979f8a  rebase fix  2025-12-27 01:04:12 +00:00
Sami Khan  91944383d3  parsing api fix  2025-12-27 01:04:12 +00:00
Evan  dcc6872724  code review followup  2025-12-27 01:04:12 +00:00
Evan  dccc2709c5  rename channel test  2025-12-24 19:52:52 +00:00
Evan  20d1246600  move macmon test  2025-12-24 19:52:52 +00:00
Evan  81bad9e01a  cleanup after rebase  2025-12-24 19:52:52 +00:00
Evan  7ff67d0a28  dedup connections  2025-12-24 19:52:52 +00:00
Evan  c888b13d3f  freeze those models  2025-12-24 19:52:52 +00:00
Evan  1f80705b56  format  2025-12-24 19:52:52 +00:00
Evan  b349330404  tidy  2025-12-24 19:52:52 +00:00
Evan  812ce47194  all master tests pass  2025-12-24 19:52:52 +00:00
Evan  643c6b8d28  ibv -> jaccl  2025-12-24 19:52:52 +00:00
Evan  4754f56bd4  tidying some horrible logic  2025-12-24 19:51:50 +00:00
Evan  66d01369b4  fix download test  2025-12-24 19:51:50 +00:00
Evan  d20d9e5fc8  fix all master tests except rdma placement  2025-12-24 19:51:50 +00:00
Evan  e67282282c  fix topology tests  2025-12-24 19:51:33 +00:00
Evan  54daa9e2db  bug  2025-12-24 19:51:33 +00:00
Evan  06125d1503  actually update the topology  2025-12-24 19:51:33 +00:00
Evan  505e756872  incorrect log  2025-12-24 19:51:33 +00:00
Evan  4cd3db0f6e  handle an error  2025-12-24 19:51:33 +00:00
Evan  8b137a1e64  fix pydantic validation  2025-12-24 19:51:33 +00:00
Evan  4176c7ec25  type checks outside of tests, time to test  2025-12-24 19:51:33 +00:00
Evan  dbce607911  wuff  2025-12-24 19:51:33 +00:00
Evan  9949b93517  rework topology  2025-12-24 19:51:33 +00:00
Evan  f4feeff077  update placement  2025-12-24 19:51:33 +00:00
Evan  f529884344  mvp  2025-12-24 19:50:31 +00:00
Evan  df4c6ce24e  tidy config  2025-12-24 19:50:31 +00:00
79 changed files with 1995 additions and 4105 deletions

.gitignore (vendored)
View File

@@ -7,8 +7,6 @@ digest.txt
# nix
.direnv/
# IDEA (PyCharm)
.idea
# xcode / macos
*.xcuserstate

View File

@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua
There are two ways to run exo:
### Run from Source (macOS)
### Run from Source (Mac & Linux)
**Prerequisites:**
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
@@ -98,62 +98,6 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
### Run from Source (Linux)
**Prerequisites:**
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
**Installation methods:**
**Option 1: Using system package manager (Ubuntu/Debian example):**
```bash
# Install Node.js and npm
sudo apt update
sudo apt install nodejs npm
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
**Option 2: Using Homebrew on Linux (if preferred):**
```bash
# Install Homebrew on Linux
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
# Install dependencies
brew install uv node
# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
**Note:** The `macmon` package is macOS-only and not required for Linux.
Clone the repo, build the dashboard, and run exo:
```bash
# Clone exo
git clone https://github.com/exo-explore/exo
# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..
# Run exo
uv run exo
```
This starts the exo dashboard and API at http://localhost:52415/
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
### macOS App
exo ships a macOS app that runs in the background on your Mac.
@@ -168,29 +112,6 @@ The app will ask for permission to modify system settings and install a new Netw
---
### Enabling RDMA on macOS
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
To enable RDMA on macOS, follow these steps:
1. Shut down your Mac.
2. Hold down the power button for 10 seconds until the boot menu appears.
3. Select "Options" to enter Recovery mode.
4. When the Recovery UI appears, open the Terminal from the Utilities menu.
5. In the Terminal, type:
```
rdma_ctl enable
```
and press Enter.
6. Reboot your Mac.
After that, RDMA will be enabled in macOS and exo will take care of the rest.
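To sanity-check the setting after the reboot, one option is to list RDMA devices. This is an assumption on our part, not a step from the original text, though the TODO list in this change mentions validating connections with `ibv_devinfo`, so the tool is expected to be available:
```bash
# List RDMA-capable devices. With RDMA enabled, the Thunderbolt 5
# interface should appear here; no devices listed means RDMA is not active.
ibv_devinfo
```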
---
### Using the API
If you prefer to interact with exo via the API, here is an example of creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request, and deleting the instance.
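The example itself is truncated in this view; the sketch below reconstructs the described flow. The chat completions endpoint is OpenAI-compatible and the port comes from the README above, but the instance create/delete paths are assumptions for illustration; check the actual API docs for the real routes.
```bash
# Sketch of the described flow. Port 52415 is from the README; the
# /instance paths are hypothetical placeholders, not confirmed endpoints.

# 1. Create an instance of the model (path is an assumption)
curl -X POST http://localhost:52415/instance \
  -H "Content-Type: application/json" \
  -d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit"}'

# 2. Send an OpenAI-compatible chat completions request
curl http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "messages": [{"role": "user", "content": "Hello!"}]
  }'

# 3. Delete the instance when done (path is an assumption)
curl -X DELETE http://localhost:52415/instance/<instance-id>
```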

View File

@@ -19,6 +19,7 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -49,7 +49,7 @@ struct ContentView: View {
private var topologySection: some View {
Group {
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
TopologyMiniView(topology: topology)
}
}

View File

@@ -82,6 +82,7 @@ struct BugReportService {
}
private func loadCredentials() throws -> AWSConfig {
// These credentials are write-only and necessary to receive bug reports from users
return AWSConfig(
accessKey: "AKIAYEKP5EMXTOBYDGHX",
secretKey: "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",

View File

@@ -7,7 +7,6 @@ final class ClusterStateService: ObservableObject {
@Published private(set) var lastError: String?
@Published private(set) var lastActionMessage: String?
@Published private(set) var modelOptions: [ModelOption] = []
@Published private(set) var localNodeId: String?
private var timer: Timer?
private let decoder: JSONDecoder
@@ -30,7 +29,6 @@ final class ClusterStateService: ObservableObject {
func startPolling(interval: TimeInterval = 0.5) {
stopPolling()
Task {
await fetchLocalNodeId()
await fetchModels()
await fetchSnapshot()
}
@@ -48,31 +46,9 @@ final class ClusterStateService: ObservableObject {
latestSnapshot = nil
lastError = nil
lastActionMessage = nil
localNodeId = nil
}
private func fetchLocalNodeId() async {
do {
let url = baseURL.appendingPathComponent("node_id")
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
return
}
if let nodeId = try? decoder.decode(String.self, from: data) {
localNodeId = nodeId
}
} catch {
// Silently ignore - localNodeId will remain nil and retry on next poll
}
}
private func fetchSnapshot() async {
// Retry fetching local node ID if not yet set
if localNodeId == nil {
await fetchLocalNodeId()
}
do {
var request = URLRequest(url: endpoint)
request.cachePolicy = .reloadIgnoringLocalCacheData

View File

@@ -85,7 +85,7 @@ struct TopologyViewModel {
}
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
func topologyViewModel() -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
guard !allNodes.isEmpty else { return nil }
@@ -105,11 +105,6 @@ extension ClusterState {
orderedNodes = allNodes
}
// Rotate so the local node (from /node_id API) is first
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
}
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
@@ -117,7 +112,10 @@ extension ClusterState {
} ?? []
let edges = Set(edgesArray)
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
let topologyRootId = topology?.nodes.first?.nodeId
let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
}
}

View File

@@ -9,8 +9,6 @@
"version": "1.0.0",
"dependencies": {
"highlight.js": "^11.11.1",
"katex": "^0.16.27",
"marked": "^17.0.1",
"mode-watcher": "^1.1.0"
},
"devDependencies": {
@@ -863,6 +861,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +901,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1518,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1528,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1941,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2251,31 +2254,6 @@
"jiti": "lib/jiti-cli.mjs"
}
},
"node_modules/katex": {
"version": "0.16.27",
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
"integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
"funding": [
"https://opencollective.com/katex",
"https://github.com/sponsors/katex"
],
"license": "MIT",
"dependencies": {
"commander": "^8.3.0"
},
"bin": {
"katex": "cli.js"
}
},
"node_modules/katex/node_modules/commander": {
"version": "8.3.0",
"resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
"integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
"license": "MIT",
"engines": {
"node": ">= 12"
}
},
"node_modules/kleur": {
"version": "4.1.5",
"resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
@@ -2562,18 +2540,6 @@
"@jridgewell/sourcemap-codec": "^1.5.5"
}
},
"node_modules/marked": {
"version": "17.0.1",
"resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
"integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
"license": "MIT",
"bin": {
"marked": "bin/marked.js"
},
"engines": {
"node": ">= 20"
}
},
"node_modules/mode-watcher": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",
@@ -2646,6 +2612,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2800,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2945,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +2967,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -27,8 +27,7 @@
},
"dependencies": {
"highlight.js": "^11.11.1",
"katex": "^0.16.27",
"marked": "^17.0.1",
"mode-watcher": "^1.1.0"
}
}

View File

@@ -198,10 +198,8 @@
stroke: oklch(0.85 0.18 85 / 0.4);
stroke-width: 1.5px;
stroke-dasharray: 8, 8;
animation: flowAnimation 1.5s linear infinite;
animation: flowAnimation 1s linear infinite;
filter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));
/* GPU optimization - hint to browser this element will animate */
will-change: stroke-dashoffset;
}
.graph-link-active {
@@ -210,24 +208,6 @@
filter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));
}
/* Reduce motion for users who prefer it - also saves GPU */
@media (prefers-reduced-motion: reduce) {
.graph-link {
animation: none;
}
.shooting-star {
animation: none;
display: none;
}
.status-pulse,
.cursor-blink,
.animate-pulse {
animation: none;
}
}
/* CRT Screen effect for topology */
.crt-screen {
position: relative;
@@ -286,15 +266,13 @@ input:focus, textarea:focus {
box-shadow: none;
}
/* Shooting Stars Animation - GPU optimized */
/* Shooting Stars Animation */
.shooting-stars {
position: fixed;
inset: 0;
overflow: hidden;
pointer-events: none;
z-index: 0;
/* Only render when visible */
content-visibility: auto;
}
.shooting-star {
@@ -307,9 +285,6 @@ input:focus, textarea:focus {
animation: shootingStar var(--duration, 3s) linear infinite;
animation-delay: var(--delay, 0s);
opacity: 0;
/* GPU optimization */
will-change: transform, opacity;
transform: translateZ(0);
}
.shooting-star::before {
@@ -345,13 +320,3 @@ input:focus, textarea:focus {
transform: translate(400px, 400px);
}
}
/* Pause animations when page is hidden to save resources */
:root:has(body[data-page-hidden="true"]) {
.shooting-star,
.graph-link,
.status-pulse,
.cursor-blink {
animation-play-state: paused;
}
}

View File

@@ -8,80 +8,89 @@
regenerateLastResponse
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import MarkdownContent from './MarkdownContent.svelte';
import { tick, onDestroy } from 'svelte';
interface Props {
class?: string;
scrollParent?: HTMLElement | null;
}
let { class: className = '', scrollParent = null }: Props = $props();
const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
let showScrollButton = $state(false);
let lastMessageCount = 0;
let containerRef: HTMLDivElement | undefined = $state();
// Ref for scroll anchor at bottom
let scrollAnchorRef: HTMLDivElement | undefined = $state();
function getScrollContainer(): HTMLElement | null {
if (scrollParent) return scrollParent;
return containerRef?.parentElement ?? null;
// Scroll management
const SCROLL_BOTTOM_THRESHOLD = 120;
let autoScrollEnabled = true;
let currentScrollEl: HTMLElement | null = null;
function resolveScrollElement(): HTMLElement | null {
if (scrollParent) return scrollParent;
let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
while (node) {
const isScrollable = node.scrollHeight > node.clientHeight + 1;
if (isScrollable) return node;
node = node.parentElement;
}
return null;
}
function isNearBottom(el: HTMLElement): boolean {
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
function handleScroll() {
if (!currentScrollEl) return;
const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
autoScrollEnabled = isNearBottom;
}
function attachScrollListener() {
const nextEl = resolveScrollElement();
if (currentScrollEl === nextEl) return;
if (currentScrollEl) {
currentScrollEl.removeEventListener('scroll', handleScroll);
}
currentScrollEl = nextEl;
if (currentScrollEl) {
currentScrollEl.addEventListener('scroll', handleScroll);
// Initialize state based on current position
handleScroll();
}
}
function scrollToBottom() {
const el = getScrollContainer();
if (el) {
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
onDestroy(() => {
if (currentScrollEl) {
currentScrollEl.removeEventListener('scroll', handleScroll);
}
});
$effect(() => {
// Re-evaluate scroll container if prop changes or after mount
scrollParent;
attachScrollListener();
});
// Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
$effect(() => {
// Track these values to trigger effect
const _ = messageList.length;
const __ = response;
const ___ = loading;
tick().then(() => {
const el = currentScrollEl ?? resolveScrollElement();
if (!el || !scrollAnchorRef) return;
const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
if (autoScrollEnabled || isNearBottom) {
scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
autoScrollEnabled = true;
}
}
function updateScrollButtonVisibility() {
const el = getScrollContainer();
if (!el) return;
showScrollButton = !isNearBottom(el);
}
// Attach scroll listener
$effect(() => {
const el = scrollParent ?? containerRef?.parentElement;
if (!el) return;
el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
// Initial check
updateScrollButtonVisibility();
return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
});
// Auto-scroll when user sends a new message
$effect(() => {
const count = messageList.length;
if (count > lastMessageCount) {
const el = getScrollContainer();
if (el) {
requestAnimationFrame(() => {
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
});
}
}
lastMessageCount = count;
});
// Update scroll button visibility when content changes
$effect(() => {
// Track response to trigger re-check during streaming
const _ = response;
// Small delay to let DOM update
requestAnimationFrame(() => updateScrollButtonVisibility());
});
});
// Edit state
let editingMessageId = $state<string | null>(null);
@@ -222,7 +231,7 @@ function isThinkingExpanded(messageId: string): boolean {
<div class="flex flex-col gap-4 sm:gap-6 {className}">
{#each messageList as message (message.id)}
<div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}">
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
{#if message.role === 'assistant'}
<!-- Assistant message header -->
<div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
@@ -296,7 +305,7 @@ function isThinkingExpanded(messageId: string): boolean {
{:else}
<div class="{message.role === 'user'
? 'command-panel rounded-lg rounded-tr-sm inline-block'
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}">
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
{#if message.role === 'user'}
<!-- User message styling -->
@@ -322,7 +331,7 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
{#if message.content}
<div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
{message.content}
</div>
{/if}
@@ -351,7 +360,7 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
<span>Thinking...</span>
</span>
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
{isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
</span>
</button>
@@ -365,8 +374,8 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
{message.content || (loading ? response : '')}
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -448,20 +457,6 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}
<!-- Invisible element for container reference -->
<div bind:this={containerRef}></div>
<!-- Scroll to bottom button -->
{#if showScrollButton}
<button
type="button"
onclick={scrollToBottom}
class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
title="Scroll to bottom"
>
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
</svg>
</button>
{/if}
<!-- Scroll anchor for auto-scroll -->
<div bind:this={scrollAnchorRef}></div>
</div>

View File

@@ -10,9 +10,7 @@ import {
clearChat,
instances,
debugMode,
toggleDebugMode,
topologyOnlyMode,
toggleTopologyOnlyMode
toggleDebugMode
} from '$lib/stores/app.svelte';
interface Props {
@@ -25,7 +23,6 @@ import {
const activeId = $derived(activeConversationId());
const instanceData = $derived(instances());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
let searchQuery = $state('');
let editingId = $state<string | null>(null);
@@ -427,19 +424,6 @@ const topologyOnlyEnabled = $derived(topologyOnlyMode());
<div class="text-xs text-white/60 font-mono tracking-wider text-center">
{conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
</div>
<button
type="button"
onclick={toggleTopologyOnlyMode}
class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
title="Toggle topology only mode"
>
<svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="5" r="2" fill="currentColor" />
<circle cx="5" cy="19" r="2" fill="currentColor" />
<circle cx="19" cy="19" r="2" fill="currentColor" />
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
</svg>
</button>
</div>
</div>
</aside>

View File

@@ -3,9 +3,6 @@
export let showHome = true;
export let onHome: (() => void) | null = null;
export let showSidebarToggle = false;
export let sidebarVisible = true;
export let onToggleSidebar: (() => void) | null = null;
function handleHome(): void {
if (onHome) {
@@ -17,38 +14,13 @@
window.location.hash = '/';
}
}
function handleToggleSidebar(): void {
if (onToggleSidebar) {
onToggleSidebar();
}
}
</script>
<header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
<!-- Left: Sidebar Toggle -->
{#if showSidebarToggle}
<div class="absolute left-6 top-1/2 -translate-y-1/2">
<button
onclick={handleToggleSidebar}
class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
>
<svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
{#if sidebarVisible}
<path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
{:else}
<path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
{/if}
</svg>
</button>
</div>
{/if}
<!-- Center: Logo (clickable to go home) -->
<button
onclick={handleHome}
class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
title={showHome ? 'Go to home' : ''}
disabled={!showHome}
>

View File

@@ -1,451 +0,0 @@
<script lang="ts">
import { marked } from 'marked';
import hljs from 'highlight.js';
import katex from 'katex';
import 'katex/dist/katex.min.css';
import { browser } from '$app/environment';
interface Props {
content: string;
class?: string;
}
let { content, class: className = '' }: Props = $props();
let containerRef = $state<HTMLDivElement>();
let processedHtml = $state('');
// Configure marked with syntax highlighting
marked.setOptions({
gfm: true,
breaks: true
});
// Custom renderer for code blocks
const renderer = new marked.Renderer();
renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
const highlighted = hljs.highlight(text, { language }).value;
const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
return `
<div class="code-block-wrapper">
<div class="code-block-header">
<span class="code-language">${language}</span>
<button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
</div>
`;
};
// Inline code
renderer.codespan = function ({ text }: { text: string }) {
return `<code class="inline-code">${text}</code>`;
};
marked.use({ renderer });
/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
*/
function preprocessLaTeX(text: string): string {
// Protect code blocks
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `<<CODE_${codeBlocks.length - 1}>>`;
});
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Restore code blocks
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
/**
* Render math expressions with KaTeX after HTML is generated
*/
function renderMath(html: string): string {
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
return html;
}
function processMarkdown(text: string): string {
try {
// Preprocess LaTeX notation
const preprocessed = preprocessLaTeX(text);
// Parse markdown
let html = marked.parse(preprocessed) as string;
// Render math expressions
html = renderMath(html);
return html;
} catch (error) {
console.error('Markdown processing error:', error);
return text.replace(/\n/g, '<br>');
}
}
async function handleCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedCode = target.getAttribute('data-code');
if (!encodedCode) return;
const code = decodeURIComponent(encodedCode);
try {
await navigator.clipboard.writeText(code);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
}
$effect(() => {
if (content) {
processedHtml = processMarkdown(content);
} else {
processedHtml = '';
}
});
$effect(() => {
if (containerRef && processedHtml) {
setupCopyButtons();
}
});
</script>
<div bind:this={containerRef} class="markdown-content {className}">
{@html processedHtml}
</div>
<style>
.markdown-content {
line-height: 1.6;
}
/* Paragraphs */
.markdown-content :global(p) {
margin-bottom: 1rem;
}
.markdown-content :global(p:last-child) {
margin-bottom: 0;
}
/* Headers */
.markdown-content :global(h1) {
font-size: 1.5rem;
font-weight: 700;
margin: 1.5rem 0 0.75rem 0;
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(h2) {
font-size: 1.25rem;
font-weight: 600;
margin: 1.25rem 0 0.5rem 0;
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(h3) {
font-size: 1.125rem;
font-weight: 600;
margin: 1rem 0 0.5rem 0;
}
.markdown-content :global(h4),
.markdown-content :global(h5),
.markdown-content :global(h6) {
font-size: 1rem;
font-weight: 600;
margin: 0.75rem 0 0.25rem 0;
}
/* Bold and italic */
.markdown-content :global(strong) {
font-weight: 600;
}
.markdown-content :global(em) {
font-style: italic;
}
/* Inline code */
.markdown-content :global(.inline-code) {
background: rgba(255, 215, 0, 0.1);
color: var(--exo-yellow, #ffd700);
padding: 0.125rem 0.375rem;
border-radius: 0.25rem;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
}
/* Links */
.markdown-content :global(a) {
color: var(--exo-yellow, #ffd700);
text-decoration: underline;
text-underline-offset: 2px;
}
.markdown-content :global(a:hover) {
opacity: 0.8;
}
/* Lists */
.markdown-content :global(ul) {
list-style-type: disc;
margin-left: 1.5rem;
margin-bottom: 1rem;
}
.markdown-content :global(ol) {
list-style-type: decimal;
margin-left: 1.5rem;
margin-bottom: 1rem;
}
.markdown-content :global(li) {
margin-bottom: 0.25rem;
}
.markdown-content :global(li::marker) {
color: var(--exo-light-gray, #9ca3af);
}
/* Blockquotes */
.markdown-content :global(blockquote) {
border-left: 3px solid var(--exo-yellow, #ffd700);
padding: 0.5rem 1rem;
margin: 1rem 0;
background: rgba(255, 215, 0, 0.05);
border-radius: 0 0.25rem 0.25rem 0;
}
/* Tables */
.markdown-content :global(table) {
width: 100%;
margin: 1rem 0;
border-collapse: collapse;
font-size: 0.875rem;
}
.markdown-content :global(th) {
background: rgba(255, 215, 0, 0.1);
border: 1px solid rgba(255, 215, 0, 0.2);
padding: 0.5rem;
text-align: left;
font-weight: 600;
}
.markdown-content :global(td) {
border: 1px solid rgba(255, 255, 255, 0.1);
padding: 0.5rem;
}
/* Horizontal rule */
.markdown-content :global(hr) {
border: none;
border-top: 1px solid rgba(255, 255, 255, 0.1);
margin: 1.5rem 0;
}
/* Code block wrapper */
.markdown-content :global(.code-block-wrapper) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.2);
background: rgba(0, 0, 0, 0.4);
}
.markdown-content :global(.code-block-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.5rem 0.75rem;
background: rgba(255, 215, 0, 0.05);
border-bottom: 1px solid rgba(255, 215, 0, 0.1);
}
.markdown-content :global(.code-language) {
color: var(--exo-yellow, #ffd700);
font-size: 0.7rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-code-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
}
.markdown-content :global(.copy-code-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-code-btn.copied) {
color: #22c55e;
}
.markdown-content :global(.code-block-wrapper pre) {
margin: 0;
padding: 1rem;
overflow-x: auto;
background: transparent;
}
.markdown-content :global(.code-block-wrapper code) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.8125rem;
line-height: 1.5;
background: transparent;
}
/* Syntax highlighting - dark theme matching EXO style */
.markdown-content :global(.hljs) {
color: #e5e7eb;
}
.markdown-content :global(.hljs-keyword),
.markdown-content :global(.hljs-selector-tag),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-section),
.markdown-content :global(.hljs-link) {
color: #c084fc;
}
.markdown-content :global(.hljs-string),
.markdown-content :global(.hljs-title),
.markdown-content :global(.hljs-name),
.markdown-content :global(.hljs-type),
.markdown-content :global(.hljs-attribute),
.markdown-content :global(.hljs-symbol),
.markdown-content :global(.hljs-bullet),
.markdown-content :global(.hljs-addition),
.markdown-content :global(.hljs-variable),
.markdown-content :global(.hljs-template-tag),
.markdown-content :global(.hljs-template-variable) {
color: #fbbf24;
}
.markdown-content :global(.hljs-comment),
.markdown-content :global(.hljs-quote),
.markdown-content :global(.hljs-deletion),
.markdown-content :global(.hljs-meta) {
color: #6b7280;
}
.markdown-content :global(.hljs-number),
.markdown-content :global(.hljs-regexp),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-built_in) {
color: #34d399;
}
.markdown-content :global(.hljs-function),
.markdown-content :global(.hljs-class .hljs-title) {
color: #60a5fa;
}
/* KaTeX math styling */
.markdown-content :global(.katex) {
font-size: 1.1em;
}
.markdown-content :global(.katex-display) {
margin: 1rem 0;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
.markdown-content :global(.katex-display > .katex) {
text-align: center;
}
.markdown-content :global(.math-error) {
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
}
</style>

View File

@@ -1,6 +1,5 @@
<script lang="ts">
import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
import { debugMode, topologyData } from '$lib/stores/app.svelte';
import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
interface Props {
model: { id: string; name?: string; storage_size_megabytes?: number };
@@ -207,8 +206,12 @@ function toggleNodeDetails(nodeId: string): void {
const centerY = topoHeight / 2;
const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;
// Only use API preview data - no local estimation
// Use API preview data if available
const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
const canFit = hasApiPreview ? true : (() => {
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
return totalAvailable >= estimatedMemory;
})();
const error = apiPreview?.error ?? null;
let placementNodes: Array<{
@@ -229,140 +232,135 @@ function toggleNodeDetails(nodeId: string): void {
modelFillHeight: number;
}> = [];
// Use API placement data directly
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
placementNodes = nodeArray.map((n, i) => {
const deltaBytes = memoryDelta[n.id] ?? 0;
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
const isUsed = deltaBytes > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
if (hasApiPreview && apiPreview.memory_delta_by_node) {
// Use API placement data
const memoryDelta = apiPreview.memory_delta_by_node;
placementNodes = nodeArray.map((n, i) => {
const deltaBytes = memoryDelta[n.id] ?? 0;
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
const isUsed = deltaBytes > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
} else if (apiPreview?.error) {
// API returned an error - model can't fit, show all nodes as unused
placementNodes = nodeArray.map((n, i) => {
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: 0,
currentPercent,
newPercent: currentPercent,
isUsed: false,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: 0
};
});
} else {
// Fallback: local estimation based on sharding strategy
const memoryNeeded = estimatedMemory;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
if (sharding === 'Pipeline') {
const memoryPerNode = memoryNeeded / numNodes;
placementNodes = nodeArray.map((n, i) => {
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: memoryPerNode,
currentPercent,
newPercent,
isUsed: true,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
} else {
let remaining = memoryNeeded;
placementNodes = nodeArray.map((n, i) => {
const allocated = Math.min(remaining, n.availableGB);
remaining -= allocated;
const isUsed = allocated > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: allocated,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
}
}
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
});
const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
const placementError = $derived(apiPreview?.error ?? null);
const nodeCount = $derived(nodeList().length);
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));
// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');
// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
if (!ip || !topology?.nodes) return null;
// Strip port if present
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Check specified node first
const node = topology.nodes[nodeId];
if (node) {
const match = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (match?.name) return match.name;
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped) return mapped;
}
// Fallback: check all nodes
for (const [, otherNode] of Object.entries(topology.nodes)) {
if (!otherNode) continue;
const match = otherNode.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (match?.name) return match.name;
const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
if (mapped) return mapped;
}
return null;
}
// Get directional arrow based on node positions
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
const dx = toNode.x - fromNode.x;
const dy = toNode.y - fromNode.y;
const absX = Math.abs(dx);
const absY = Math.abs(dy);
if (absX > absY * 2) {
return dx > 0 ? '→' : '←';
} else if (absY > absX * 2) {
return dy > 0 ? '↓' : '↑';
} else {
if (dx > 0 && dy > 0) return '↘';
if (dx > 0 && dy < 0) return '↗';
if (dx < 0 && dy > 0) return '↙';
return '↖';
}
}
// Get connection info for edges between two nodes
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
if (!topology?.edges) return [];
// Collect candidates for each direction
const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];
for (const edge of topology.edges) {
const ip = edge.sendBackIp || '?';
const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
if (edge.source === nodeId1 && edge.target === nodeId2) {
aToBCandidates.push({ ip, iface });
} else if (edge.source === nodeId2 && edge.target === nodeId1) {
bToACandidates.push({ ip, iface });
}
}
// Pick best (prefer non-loopback)
const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
if (candidates.length === 0) return null;
return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
};
const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];
const bestAtoB = pickBest(aToBCandidates);
if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });
const bestBtoA = pickBest(bToACandidates);
if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });
return result;
}
</script>
<div class="relative group">
@@ -455,26 +453,6 @@ function toggleNodeDetails(nodeId: string): void {
<!-- Connection lines between nodes (if multiple) -->
{#if preview.nodes.length > 1}
{@const usedNodes = preview.nodes.filter(n => n.isUsed)}
{@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
{@const allConnections = isDebugMode && usedNodes.length > 1 ? (() => {
const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
for (let i = 0; i < usedNodes.length; i++) {
for (let j = i + 1; j < usedNodes.length; j++) {
const n1 = usedNodes[i];
const n2 = usedNodes[j];
const midX = (n1.x + n2.x) / 2;
const midY = (n1.y + n2.y) / 2;
for (const c of getConnectionInfo(n1.id, n2.id)) {
const fromPos = nodePositions[c.from];
const toPos = nodePositions[c.to];
const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
conns.push({ ...c, midX, midY, arrow });
}
}
}
return conns;
})() : []}
{#each preview.nodes as node, i}
{#each preview.nodes.slice(i + 1) as node2}
<line
@@ -486,43 +464,6 @@ function toggleNodeDetails(nodeId: string): void {
/>
{/each}
{/each}
<!-- Debug: Show connection IPs/interfaces in corners -->
{#if isDebugMode && allConnections.length > 0}
{@const centerX = preview.topoWidth / 2}
{@const centerY = preview.topoHeight / 2}
{@const quadrants = {
topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
}}
{@const padding = 4}
{@const lineHeight = 8}
<!-- Top Left -->
{#each quadrants.topLeft as conn, idx}
<text x={padding} y={padding + idx * lineHeight} text-anchor="start" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Top Right -->
{#each quadrants.topRight as conn, idx}
<text x={preview.topoWidth - padding} y={padding + idx * lineHeight} text-anchor="end" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Bottom Left -->
{#each quadrants.bottomLeft as conn, idx}
<text x={padding} y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight} text-anchor="start" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Bottom Right -->
{#each quadrants.bottomRight as conn, idx}
<text x={preview.topoWidth - padding} y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight} text-anchor="end" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
{/if}
{/if}
{#each preview.nodes as node}

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { onMount, onDestroy, tick } from 'svelte';
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
@@ -12,35 +12,11 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv
let svgContainer: SVGSVGElement | undefined = $state();
let resizeObserver: ResizeObserver | undefined;
// Optimization: Track last render state to avoid unnecessary re-renders
let lastRenderHash = '';
let lastHighlightedNodesHash = '';
let lastDimensions = { width: 0, height: 0 };
let isRendering = false;
let pendingRender = false;
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
// Generate a hash of relevant data to detect actual changes
function generateDataHash(topologyData: typeof data, minimized: boolean, debug: boolean): string {
if (!topologyData) return 'null';
const nodes = topologyData.nodes || {};
const edges = topologyData.edges || [];
// Create a lightweight hash from key properties only
const nodeHashes = Object.entries(nodes).map(([id, n]) => {
const macmon = n.macmon_info;
return `${id}:${n.friendly_name || ''}:${macmon?.memory?.ram_usage || 0}:${macmon?.memory?.ram_total || 0}:${macmon?.temp?.gpu_temp_avg || 0}:${macmon?.gpu_usage?.[1] || 0}:${macmon?.sys_power || 0}`;
}).sort().join('|');
const edgeHash = edges.map(e => `${e.source}-${e.target}`).sort().join(',');
return `${nodeHashes}::${edgeHash}::${minimized}::${debug}`;
}
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
@@ -48,36 +24,19 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (matchFromInterfaces?.name) {
return matchFromInterfaces.name;
}
const node = data?.nodes?.[nodeId];
if (!node) return { label: '?', missing: true };
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === ip)
);
if (matchFromInterfaces?.name) {
return { label: matchFromInterfaces.name, missing: false };
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
if (otherResult) return { label: otherResult, missing: false };
const mapped = node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return { label: mapped, missing: false };
}
return { label: '?', missing: true };
@@ -108,7 +67,6 @@ function wrapLine(text: string, maxLen: number): string[] {
return lines;
}
// Apple logo path for MacBook Pro screen
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
const LOGO_NATIVE_WIDTH = 814;
@@ -280,7 +238,6 @@ function wrapLine(text: string, maxLen: number): string[] {
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
@@ -357,97 +314,109 @@ function wrapLine(text: string, maxLen: number): string[] {
.attr('marker-end', 'url(#arrowhead)');
}
// Collect debug labels for later positioning at edges
if (debugEnabled && entry.connections.length > 0) {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
isTop,
mx,
my
});
}
});
const maxBoxes = 6;
const fontSize = isMinimized ? 8 : 9;
const lineGap = 2;
const labelOffsetOut = Math.max(140, minDimension * 0.38);
const labelOffsetSide = isMinimized ? 16 : 20;
const boxWidth = 170;
const maxLineLen = 26;
// Render debug labels at viewport edges/corners
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
const fontSize = isMinimized ? 10 : 12;
const lineHeight = fontSize + 4;
const padding = 10;
// Helper to get arrow based on direction vector
function getArrow(fromId: string, toId: string): string {
const fromPos = positionById[fromId];
const toPos = positionById[toId];
if (!fromPos || !toPos) return '→';
const dirX = toPos.x - fromPos.x;
const dirY = toPos.y - fromPos.y;
const absX = Math.abs(dirX);
const absY = Math.abs(dirY);
if (absX > absY * 2) {
return dirX > 0 ? '→' : '←';
} else if (absY > absX * 2) {
return dirY > 0 ? '↓' : '↑';
} else {
if (dirX > 0 && dirY > 0) return '↘';
if (dirX > 0 && dirY < 0) return '↗';
if (dirX < 0 && dirY > 0) return '↙';
return '↖';
const connections = entry.connections.slice(0, maxBoxes);
if (entry.connections.length > maxBoxes) {
const remaining = entry.connections.length - maxBoxes;
connections.push({
from: '',
to: '',
ip: `(+${remaining} more)`,
ifaceLabel: '',
missingIface: false
});
}
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
debugLabelsGroup.append('text')
.attr('x', baseX)
.attr('y', currentY)
.attr('text-anchor', textAnchor)
.attr('dominant-baseline', isTop ? 'hanging' : 'auto')
let dirX = mx - centerX;
let dirY = my - centerY;
const dirLen = Math.hypot(dirX, dirY);
if (dirLen < 1) {
dirX = -uy;
dirY = ux;
} else {
dirX /= dirLen;
dirY /= dirLen;
}
const nx = -dirY;
const ny = dirX;
const labelXRaw = mx + dirX * labelOffsetOut + nx * labelOffsetSide;
const labelYRaw = my + dirY * labelOffsetOut + ny * labelOffsetSide;
const clampPad = Math.min(120, minDimension * 0.12);
const labelX = Math.max(clampPad, Math.min(width - clampPad, labelXRaw));
const labelY = Math.max(clampPad, Math.min(height - clampPad, labelYRaw));
const labelGroup = debugLabelsGroup.append('g')
.attr('transform', `translate(${labelX}, ${labelY})`);
const textGroup = labelGroup.append('g');
connections.forEach((conn, idx) => {
const rawLines = conn.from && conn.to
? [
`${getNodeLabel(conn.from)}${getNodeLabel(conn.to)}`,
`${conn.ip}`,
`${conn.ifaceLabel}`
]
: [conn.ip];
const wrapped = rawLines.flatMap(line => wrapLine(line, maxLineLen));
wrapped.forEach((line, lineIdx) => {
textGroup.append('text')
.attr('x', 0)
.attr('y', (idx * (wrapped.length * (fontSize + lineGap))) + lineIdx * (fontSize + lineGap))
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'hanging')
.attr('font-size', fontSize)
.attr('font-family', 'SF Mono, monospace')
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
.text(label);
currentY += isTop ? lineHeight : -lineHeight;
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.9)')
.text(line);
});
});
});
}
const bbox = textGroup.node()?.getBBox();
if (bbox) {
const paddedWidth = Math.max(boxWidth, bbox.width + 14);
const boxHeight = bbox.height + 8;
const boxMinX = labelX - paddedWidth / 2;
const boxMaxX = labelX + paddedWidth / 2;
const boxMinY = labelY + bbox.y - 4;
const boxMaxY = boxMinY + boxHeight;
const clampPadDynamic = Math.min(140, minDimension * 0.18);
let shiftX = 0;
let shiftY = 0;
if (boxMinX < clampPadDynamic) shiftX = clampPadDynamic - boxMinX;
if (boxMaxX > width - clampPadDynamic) shiftX = (width - clampPadDynamic) - boxMaxX;
if (boxMinY < clampPadDynamic) shiftY = clampPadDynamic - boxMinY;
if (boxMaxY > height - clampPadDynamic) shiftY = (height - clampPadDynamic) - boxMaxY;
const finalX = labelX + shiftX;
const finalY = labelY + shiftY;
labelGroup.attr('transform', `translate(${finalX}, ${finalY})`);
labelGroup.insert('rect', 'g')
.attr('x', -paddedWidth / 2)
.attr('y', bbox.y - 4)
.attr('width', paddedWidth)
.attr('height', boxHeight)
.attr('rx', 4)
.attr('fill', 'rgba(0,0,0,0.75)')
.attr('stroke', 'rgba(255,255,255,0.12)')
.attr('stroke-width', 0.6);
}
}
});
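
The bounding-box pass above shifts each label group back inside the viewport before drawing its backing rect. The clamp in isolation, as a hedged sketch (helper name hypothetical, not part of the component):

// Sketch: shift a w×h box centred at (x, y) so it stays at least
// `pad` pixels inside a width×height viewport.
function clampBoxIntoViewport(
  x: number, y: number, w: number, h: number,
  width: number, height: number, pad: number
): { x: number; y: number } {
  let shiftX = 0;
  let shiftY = 0;
  if (x - w / 2 < pad) shiftX = pad - (x - w / 2);
  if (x + w / 2 > width - pad) shiftX = (width - pad) - (x + w / 2);
  if (y - h / 2 < pad) shiftY = pad - (y - h / 2);
  if (y + h / 2 > height - pad) shiftY = (height - pad) - (y + h / 2);
  return { x: x + shiftX, y: y + shiftY };
}
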
// Draw nodes
const nodesGroup = svg.append('g').attr('class', 'nodes-group');
@@ -956,59 +925,16 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Throttled render function to prevent too-frequent updates
function scheduleRender() {
if (isRendering) {
pendingRender = true;
return;
}
isRendering = true;
requestAnimationFrame(() => {
renderGraph();
isRendering = false;
if (pendingRender) {
pendingRender = false;
scheduleRender();
}
});
}
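
The scheduleRender helper being removed here coalesced bursts of state changes into at most one render per animation frame, re-running once more if changes arrived mid-frame. The same pattern as a standalone sketch:

// Sketch: coalesce rapid calls into one invocation per animation frame.
// If calls arrive while a frame is pending, run exactly once more after it.
function makeFrameCoalesced(render: () => void): () => void {
  let isRendering = false;
  let pending = false;
  const schedule = () => {
    if (isRendering) { pending = true; return; }
    isRendering = true;
    requestAnimationFrame(() => {
      render();
      isRendering = false;
      if (pending) { pending = false; schedule(); }
    });
  };
  return schedule;
}
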
$effect(() => {
if (!data || !svgContainer) return;
// Generate hash of current state
const currentHash = generateDataHash(data, isMinimized, debugEnabled);
const highlightHash = Array.from(highlightedNodes).sort().join(',');
// Get current dimensions
const rect = svgContainer.getBoundingClientRect();
const dimensionsChanged = rect.width !== lastDimensions.width || rect.height !== lastDimensions.height;
// Only re-render if something actually changed
if (currentHash !== lastRenderHash || highlightHash !== lastHighlightedNodesHash || dimensionsChanged) {
lastRenderHash = currentHash;
lastHighlightedNodesHash = highlightHash;
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
if (data) {
renderGraph();
}
});
onMount(() => {
if (svgContainer) {
// Use a debounced resize observer to prevent rapid re-renders
let resizeTimeout: ReturnType<typeof setTimeout> | null = null;
resizeObserver = new ResizeObserver(() => {
if (resizeTimeout) clearTimeout(resizeTimeout);
resizeTimeout = setTimeout(() => {
const rect = svgContainer!.getBoundingClientRect();
if (rect.width !== lastDimensions.width || rect.height !== lastDimensions.height) {
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
}, 100);
renderGraph();
});
resizeObserver.observe(svgContainer);
}
@@ -1036,20 +962,10 @@ function wrapLine(text: string, maxLen: number): string[] {
stroke-width: 1px;
stroke-dasharray: 4, 4;
opacity: 0.8;
/* Slower animation = less GPU usage */
animation: flowAnimation 2s linear infinite;
/* GPU optimization */
will-change: stroke-dashoffset;
animation: flowAnimation 0.75s linear infinite;
}
@keyframes flowAnimation {
from { stroke-dashoffset: 0; }
to { stroke-dashoffset: -10; }
}
/* Respect reduced motion preference */
@media (prefers-reduced-motion: reduce) {
:global(.graph-link) {
animation: none;
}
}
</style>

View File

@@ -4,5 +4,4 @@ export { default as ChatMessages } from './ChatMessages.svelte';
export { default as ChatAttachments } from './ChatAttachments.svelte';
export { default as ChatSidebar } from './ChatSidebar.svelte';
export { default as ModelCard } from './ModelCard.svelte';
export { default as MarkdownContent } from './MarkdownContent.svelte';

View File

@@ -96,7 +96,7 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile: RawNodeProfile;
nodeProfile?: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -105,9 +105,13 @@ interface RawTopologyConnection {
sendBackMultiaddr?: { multiaddr?: string; address?: string; ip_address?: string } | string;
}
// Connection can be an object or a tuple [source, target, metadata]
type RawConnectionItem = RawTopologyConnection | [string, string, { sinkMultiaddr?: { ip_address?: string; address?: string } }?];
interface RawTopology {
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
connections?: RawConnectionItem[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
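
These widened types accept two wire encodings of the same topology: nodes either as bare ID strings or as objects, and connections either as objects or as [source, target, metadata] tuples. A self-contained sketch of the connection normalization (the local type alias stands in for RawConnectionItem):

// Sketch: collapse both connection encodings into one shape.
type ConnItem =
  | { localNodeId?: string; sendBackNodeId?: string; sendBackMultiaddr?: unknown }
  | [string, string, { sinkMultiaddr?: unknown }?];
function normalizeConnection(conn: ConnItem) {
  if (Array.isArray(conn)) {
    // Tuple format: [source, target, metadata]
    const [localNodeId, sendBackNodeId, metadata] = conn;
    return { localNodeId, sendBackNodeId, sendBackMultiaddr: metadata?.sinkMultiaddr };
  }
  // Object format with localNodeId/sendBackNodeId
  const { localNodeId, sendBackNodeId, sendBackMultiaddr } = conn;
  return { localNodeId, sendBackNodeId, sendBackMultiaddr };
}
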
@@ -198,9 +202,17 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === 'string' ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode = typeof node === 'object' ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -238,7 +250,7 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
}
nodes[node.nodeId] = {
nodes[nodeId] = {
system_info: {
model_id: profile?.modelId ?? 'Unknown',
chip: profile?.chipId,
@@ -260,14 +272,34 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
};
}
// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
for (const conn of raw.connections || []) {
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
let localNodeId: string | undefined;
let sendBackNodeId: string | undefined;
let sendBackMultiaddr: { multiaddr?: string; address?: string; ip_address?: string } | string | undefined;
// Check if it's a tuple format [source, target, metadata]
if (Array.isArray(conn)) {
localNodeId = conn[0] as string;
sendBackNodeId = conn[1] as string;
const metadata = conn[2] as { sinkMultiaddr?: { ip_address?: string; address?: string } } | undefined;
if (metadata?.sinkMultiaddr) {
sendBackMultiaddr = metadata.sinkMultiaddr;
}
} else {
// Object format with localNodeId/sendBackNodeId
localNodeId = conn.localNodeId;
sendBackNodeId = conn.sendBackNodeId;
sendBackMultiaddr = conn.sendBackMultiaddr;
}
if (!localNodeId || !sendBackNodeId) continue;
if (localNodeId === sendBackNodeId) continue;
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
let sendBackIp: string | undefined;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (sendBackMultiaddr) {
const multi = sendBackMultiaddr;
if (typeof multi === 'string') {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
@@ -276,8 +308,8 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
edges.push({
source: conn.localNodeId,
target: conn.sendBackNodeId,
source: localNodeId,
target: sendBackNodeId,
sendBackIp
});
}
@@ -297,35 +329,6 @@ function extractIpFromMultiaddr(ma?: string): string | undefined {
return undefined;
}
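
extractIpFromMultiaddr pulls the host out of multiaddr strings such as /ip4/192.168.1.7/tcp/52415. One tolerant way to write it; a sketch, not necessarily the store's exact implementation:

// Sketch: extract the address component of "/ip4/<addr>/..." or "/ip6/<addr>/...".
function extractIp(ma?: string): string | undefined {
  if (!ma) return undefined;
  const m = ma.match(/^\/ip[46]\/([^/]+)/);
  return m ? m[1] : undefined;
}
// extractIp('/ip4/192.168.1.7/tcp/52415') === '192.168.1.7'
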
// Deep comparison utility for preventing unnecessary state updates
function shallowEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
if (a === null || b === null) return false;
if (typeof a !== 'object' || typeof b !== 'object') return false;
const aObj = a as Record<string, unknown>;
const bObj = b as Record<string, unknown>;
const aKeys = Object.keys(aObj);
const bKeys = Object.keys(bObj);
if (aKeys.length !== bKeys.length) return false;
for (const key of aKeys) {
if (aObj[key] !== bObj[key]) return false;
}
return true;
}
// Faster JSON comparison for complex nested objects
function jsonEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
try {
return JSON.stringify(a) === JSON.stringify(b);
} catch {
return false;
}
}
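
One caveat if these helpers ever come back: JSON.stringify comparison is key-order sensitive, so jsonEqual reports two semantically equal objects as different when their keys serialize in different orders:

// jsonEqual's caveat: key order changes the serialized form.
JSON.stringify({ a: 1, b: 2 }); // '{"a":1,"b":2}'
JSON.stringify({ b: 2, a: 1 }); // '{"b":2,"a":1}' — so jsonEqual would return false
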
class AppStore {
// Conversation state
conversations = $state<Conversation[]>([]);
@@ -356,49 +359,19 @@ class AppStore {
isTopologyMinimized = $state(false);
isSidebarOpen = $state(false); // Hidden by default, shown when in chat mode
debugMode = $state(false);
topologyOnlyMode = $state(false);
chatSidebarVisible = $state(true); // Shown by default
// Visibility state - used to pause polling when tab is hidden
private isPageVisible = true;
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
// Cache for comparison - prevents unnecessary reactivity
private lastTopologyJson = '';
private lastInstancesJson = '';
private lastRunnersJson = '';
private lastDownloadsJson = '';
constructor() {
if (browser) {
this.startPolling();
this.loadConversationsFromStorage();
this.loadDebugModeFromStorage();
this.loadTopologyOnlyModeFromStorage();
this.loadChatSidebarVisibleFromStorage();
this.setupVisibilityListener();
}
}
/**
* Listen for page visibility changes to pause polling when hidden
*/
private setupVisibilityListener() {
if (typeof document === 'undefined') return;
document.addEventListener('visibilitychange', () => {
this.isPageVisible = document.visibilityState === 'visible';
if (this.isPageVisible) {
// Resume polling when page becomes visible
this.fetchState();
}
});
}
/**
* Load conversations from localStorage
*/
@@ -453,44 +426,6 @@ class AppStore {
}
}
private loadTopologyOnlyModeFromStorage() {
try {
const stored = localStorage.getItem('exo-topology-only-mode');
if (stored !== null) {
this.topologyOnlyMode = stored === 'true';
}
} catch (error) {
console.error('Failed to load topology only mode:', error);
}
}
private saveTopologyOnlyModeToStorage() {
try {
localStorage.setItem('exo-topology-only-mode', this.topologyOnlyMode ? 'true' : 'false');
} catch (error) {
console.error('Failed to save topology only mode:', error);
}
}
private loadChatSidebarVisibleFromStorage() {
try {
const stored = localStorage.getItem('exo-chat-sidebar-visible');
if (stored !== null) {
this.chatSidebarVisible = stored === 'true';
}
} catch (error) {
console.error('Failed to load chat sidebar visibility:', error);
}
}
private saveChatSidebarVisibleToStorage() {
try {
localStorage.setItem('exo-chat-sidebar-visible', this.chatSidebarVisible ? 'true' : 'false');
} catch (error) {
console.error('Failed to save chat sidebar visibility:', error);
}
}
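
The four load/save pairs deleted here all follow one localStorage-backed boolean pattern; a hypothetical helper that would collapse them:

// Hypothetical: one factory in place of four load/save method pairs.
function persistedBool(key: string, fallback: boolean) {
  return {
    load(): boolean {
      try {
        const stored = localStorage.getItem(key);
        return stored !== null ? stored === 'true' : fallback;
      } catch { return fallback; }
    },
    save(value: boolean): void {
      try { localStorage.setItem(key, value ? 'true' : 'false'); }
      catch (error) { console.error(`Failed to save ${key}:`, error); }
    },
  };
}
// const topologyOnly = persistedBool('exo-topology-only-mode', false);
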
/**
* Create a new conversation
*/
@@ -795,39 +730,9 @@ class AppStore {
this.saveDebugModeToStorage();
}
getTopologyOnlyMode(): boolean {
return this.topologyOnlyMode;
}
setTopologyOnlyMode(enabled: boolean) {
this.topologyOnlyMode = enabled;
this.saveTopologyOnlyModeToStorage();
}
toggleTopologyOnlyMode() {
this.topologyOnlyMode = !this.topologyOnlyMode;
this.saveTopologyOnlyModeToStorage();
}
getChatSidebarVisible(): boolean {
return this.chatSidebarVisible;
}
setChatSidebarVisible(visible: boolean) {
this.chatSidebarVisible = visible;
this.saveChatSidebarVisibleToStorage();
}
toggleChatSidebarVisible() {
this.chatSidebarVisible = !this.chatSidebarVisible;
this.saveChatSidebarVisibleToStorage();
}
startPolling() {
this.fetchState();
// Poll every 2 seconds instead of 1 second - reduces CPU/GPU load by 50%
// Data comparison ensures we only update when something actually changes
this.fetchInterval = setInterval(() => this.fetchState(), 2000);
this.fetchInterval = setInterval(() => this.fetchState(), 1000);
}
stopPolling() {
@@ -839,9 +744,6 @@ class AppStore {
}
async fetchState() {
// Skip polling when page is hidden to save resources
if (!this.isPageVisible) return;
try {
const response = await fetch('/state');
if (!response.ok) {
@@ -849,44 +751,19 @@ class AppStore {
}
const data: RawStateResponse = await response.json();
// Only update topology if it actually changed (prevents unnecessary D3 re-renders)
if (data.topology) {
const newTopology = transformTopology(data.topology, data.nodeProfiles);
const newTopologyJson = JSON.stringify(newTopology);
if (newTopologyJson !== this.lastTopologyJson) {
this.lastTopologyJson = newTopologyJson;
this.topologyData = newTopology;
}
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
}
// Only update instances if changed
if (data.instances) {
const newInstancesJson = JSON.stringify(data.instances);
if (newInstancesJson !== this.lastInstancesJson) {
this.lastInstancesJson = newInstancesJson;
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
// Only update runners if changed
if (data.runners) {
const newRunnersJson = JSON.stringify(data.runners);
if (newRunnersJson !== this.lastRunnersJson) {
this.lastRunnersJson = newRunnersJson;
this.runners = data.runners;
}
this.runners = data.runners;
}
// Only update downloads if changed
if (data.downloads) {
const newDownloadsJson = JSON.stringify(data.downloads);
if (newDownloadsJson !== this.lastDownloadsJson) {
this.lastDownloadsJson = newDownloadsJson;
this.downloads = data.downloads;
}
this.downloads = data.downloads;
}
this.lastUpdate = Date.now();
} catch (error) {
console.error('Error fetching state:', error);
@@ -1043,6 +920,8 @@ class AppStore {
if (lastUserIndex === -1) return;
const lastUserMessage = this.messages[lastUserIndex];
// Remove any messages after the user message
this.messages = this.messages.slice(0, lastUserIndex + 1);
@@ -1083,10 +962,7 @@ class AppStore {
}
if (!modelToUse) {
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = 'Error: No model available. Please launch an instance first.';
}
assistantMessage.content = 'Error: No model available. Please launch an instance first.';
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -1104,10 +980,7 @@ class AppStore {
if (!response.ok) {
const errorText = await response.text();
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = `Error: ${response.status} - ${errorText}`;
}
assistantMessage.content = `Error: ${response.status} - ${errorText}`;
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -1115,10 +988,7 @@ class AppStore {
const reader = response.body?.getReader();
if (!reader) {
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = 'Error: No response stream available';
}
assistantMessage.content = 'Error: No response stream available';
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -1146,16 +1016,9 @@ class AppStore {
const delta = json.choices?.[0]?.delta?.content;
if (delta) {
fullContent += delta;
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
const { displayContent } = this.stripThinkingTags(fullContent);
this.currentResponse = displayContent;
// Update the assistant message in place (triggers Svelte reactivity)
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
}
this.persistActiveConversation();
assistantMessage.content = displayContent;
}
} catch {
// Skip malformed JSON
@@ -1164,25 +1027,16 @@ class AppStore {
}
}
// Final cleanup of the message
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
}
this.persistActiveConversation();
} catch (error) {
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
}
this.persistActiveConversation();
} finally {
this.isLoading = false;
const { displayContent } = this.stripThinkingTags(fullContent);
assistantMessage.content = displayContent;
this.currentResponse = '';
this.updateActiveConversation();
} catch (error) {
assistantMessage.content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
this.updateActiveConversation();
} finally {
this.isLoading = false;
}
}
@@ -1542,8 +1396,6 @@ export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const startChat = () => appStore.startChat();
@@ -1571,9 +1423,5 @@ export const isSidebarOpen = () => appStore.isSidebarOpen;
export const toggleSidebar = () => appStore.toggleSidebar();
export const toggleDebugMode = () => appStore.toggleDebugMode();
export const setDebugMode = (enabled: boolean) => appStore.setDebugMode(enabled);
export const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode();
export const setTopologyOnlyMode = (enabled: boolean) => appStore.setTopologyOnlyMode(enabled);
export const toggleChatSidebarVisible = () => appStore.toggleChatSidebarVisible();
export const setChatSidebarVisible = (visible: boolean) => appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();

View File

@@ -1,25 +1,7 @@
<script lang="ts">
import '../app.css';
import { onMount } from 'svelte';
import { browser } from '$app/environment';
let { children } = $props();
let isPageHidden = $state(false);
onMount(() => {
if (!browser) return;
// Listen for visibility changes to pause animations when hidden
const handleVisibilityChange = () => {
isPageHidden = document.visibilityState === 'hidden';
};
document.addEventListener('visibilitychange', handleVisibilityChange);
return () => {
document.removeEventListener('visibilitychange', handleVisibilityChange);
};
});
</script>
<svelte:head>
@@ -27,7 +9,7 @@
<meta name="description" content="EXO - Distributed AI Cluster Dashboard" />
</svelte:head>
<div class="min-h-screen bg-background text-foreground" data-page-hidden={isPageHidden}>
<div class="min-h-screen bg-background text-foreground">
{@render children?.()}
</div>

View File

@@ -18,10 +18,6 @@
selectedChatModel,
debugMode,
toggleDebugMode,
topologyOnlyMode,
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
type DownloadProgress,
type PlacementPreview
} from '$lib/stores/app.svelte';
@@ -41,8 +37,6 @@
const selectedModelId = $derived(selectedPreviewModelId());
const loadingPreviews = $derived(isLoadingPreviews());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
@@ -99,35 +93,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Compute highlighted nodes from hovered instance or hovered preview
// Memoized to avoid creating new Sets on every render
let lastHighlightedNodesKey = '';
let cachedHighlightedNodes: Set<string> = new Set();
const highlightedNodes = $derived(() => {
// Create a key for the current state to enable memoization
const previewKey = Array.from(hoveredPreviewNodes).sort().join(',');
const currentKey = `${hoveredInstanceId || 'null'}:${previewKey}`;
// Return cached value if nothing changed
if (currentKey === lastHighlightedNodesKey) {
return cachedHighlightedNodes;
}
lastHighlightedNodesKey = currentKey;
// First check instance hover
if (hoveredInstanceId) {
const instanceWrapped = instanceData[hoveredInstanceId];
cachedHighlightedNodes = unwrapInstanceNodes(instanceWrapped);
return cachedHighlightedNodes;
return unwrapInstanceNodes(instanceWrapped);
}
// Then check preview hover
if (hoveredPreviewNodes.size > 0) {
cachedHighlightedNodes = hoveredPreviewNodes;
return cachedHighlightedNodes;
return hoveredPreviewNodes;
}
cachedHighlightedNodes = new Set<string>();
return cachedHighlightedNodes;
return new Set<string>();
});
// Helper to estimate memory from model ID (mirrors ModelCard logic)
@@ -496,7 +472,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const progress = parseDownloadProgress(downloadPayload);
if (progress) {
// Sum all values across nodes - each node downloads independently
totalBytes += progress.totalBytes;
downloadedBytes += progress.downloadedBytes;
totalSpeed += progress.speed;
@@ -514,17 +489,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
return { isDownloading: false, progress: null, perNode: [] };
}
// ETA = total remaining bytes / total speed across all nodes
const remainingBytes = totalBytes - downloadedBytes;
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
return {
isDownloading: true,
progress: {
totalBytes,
downloadedBytes,
speed: totalSpeed,
etaMs,
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
completedFiles,
totalFiles,
@@ -534,13 +505,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
};
}
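
Inlining keeps the ETA arithmetic unchanged: remaining bytes summed across nodes, divided by the nodes' combined speed, converted to milliseconds, with a zero-speed guard. A worked sketch:

// ETA sketch: remaining bytes / combined speed, in milliseconds.
const etaMs = (totalBytes: number, downloadedBytes: number, speed: number) =>
  speed > 0 ? ((totalBytes - downloadedBytes) / speed) * 1000 : 0;
// etaMs(5e9, 2e9, 60e6) === 50000  (3 GB left at a combined 60 MB/s → 50 s)
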
// Debug: Log downloads data when it changes (disabled in production for performance)
// Uncomment for debugging:
// $effect(() => {
// if (downloadsData && Object.keys(downloadsData).length > 0) {
// console.log('[Download Debug] Current downloads:', downloadsData);
// }
// });
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log('[Download Debug] Current downloads:', downloadsData);
}
});
// Helper to get download status for an instance
function getInstanceDownloadStatus(instanceId: string, instanceWrapped: unknown): {
@@ -606,7 +576,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const progress = parseDownloadProgress(downloadPayload);
if (progress) {
// Sum all values across nodes - each node downloads independently
totalBytes += progress.totalBytes;
downloadedBytes += progress.downloadedBytes;
totalSpeed += progress.speed;
@@ -627,17 +596,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
return { isDownloading: false, progress: null, statusText: statusInfo.statusText, perNode: [] };
}
// ETA = total remaining bytes / total speed across all nodes
const remainingBytes = totalBytes - downloadedBytes;
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
return {
isDownloading: true,
progress: {
totalBytes,
downloadedBytes,
speed: totalSpeed,
etaMs,
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
completedFiles,
totalFiles,
@@ -653,12 +618,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function getStatusColor(statusText: string): string {
switch (statusText) {
case 'FAILED': return 'text-red-400';
case 'SHUTDOWN': return 'text-gray-400';
case 'DOWNLOADING': return 'text-blue-400';
case 'LOADING':
case 'WARMING UP':
case 'WAITING':
case 'INITIALIZING': return 'text-yellow-400';
case 'WAITING': return 'text-yellow-400';
case 'RUNNING': return 'text-teal-400';
case 'READY':
case 'LOADED': return 'text-green-400';
@@ -681,15 +644,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!r) return null;
const [kind] = getTagged(r);
const statusMap: Record<string, string> = {
RunnerWaitingForInitialization: 'WaitingForInitialization',
RunnerInitializingBackend: 'InitializingBackend',
RunnerWaitingForModel: 'WaitingForModel',
RunnerLoading: 'Loading',
RunnerLoaded: 'Loaded',
RunnerWarmingUp: 'WarmingUp',
RunnerReady: 'Ready',
RunnerRunning: 'Running',
RunnerShutdown: 'Shutdown',
RunnerFailed: 'Failed',
};
return kind ? statusMap[kind] || null : null;
@@ -700,15 +660,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
if (has('WarmingUp')) return { statusText: 'WARMING UP', statusClass: 'starting' };
if (has('Running')) return { statusText: 'RUNNING', statusClass: 'running' };
if (has('Ready')) return { statusText: 'READY', statusClass: 'loaded' };
if (has('Loaded')) return { statusText: 'LOADED', statusClass: 'loaded' };
if (has('WaitingForModel')) return { statusText: 'WAITING', statusClass: 'starting' };
if (has('InitializingBackend')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
if (has('WaitingForInitialization')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
return { statusText: 'RUNNING', statusClass: 'active' };
}
@@ -1150,47 +1107,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="shooting-star" style="top: 50%; left: 40%; --duration: 45s; --delay: 30s;"></div>
</div>
{#if !topologyOnlyEnabled}
<HeaderNav
showHome={chatStarted}
onHome={handleGoHome}
showSidebarToggle={true}
sidebarVisible={sidebarVisible}
onToggleSidebar={toggleChatSidebarVisible}
/>
{/if}
<HeaderNav showHome={chatStarted} onHome={handleGoHome} />
<!-- Main Content -->
<main class="flex-1 flex overflow-hidden relative">
<!-- Left: Conversation History Sidebar (hidden in topology-only mode or when toggled off) -->
{#if !topologyOnlyEnabled && sidebarVisible}
<!-- Left: Conversation History Sidebar (always visible) -->
<div class="w-80 flex-shrink-0 border-r border-exo-yellow/10">
<ChatSidebar class="h-full" />
</div>
{/if}
{#if topologyOnlyEnabled}
<!-- TOPOLOGY ONLY MODE: Full-screen topology -->
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
<!-- Exit topology-only mode button -->
<button
type="button"
onclick={toggleTopologyOnlyMode}
class="absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm"
title="Exit topology only mode"
>
<svg class="w-5 h-5 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="5" r="2" fill="currentColor" />
<circle cx="5" cy="19" r="2" fill="currentColor" />
<circle cx="19" cy="19" r="2" fill="currentColor" />
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
</svg>
</button>
</div>
</div>
{:else if !chatStarted}
{#if !chatStarted}
<!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->
<div class="flex-1 flex overflow-visible relative" in:fade={{ duration: 300 }} out:fade={{ duration: 200 }}>
@@ -1374,15 +1300,14 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{:else}
{#each nodeProg.progress.files as f}
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
{@const isFileComplete = filePercent >= 100}
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
<span class="truncate pr-2">{f.name}</span>
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
<span class="text-white/80">{filePercent.toFixed(1)}%</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
style="width: {filePercent.toFixed(1)}%"
></div>
</div>
@@ -1686,13 +1611,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
in:fade={{ duration: 300, delay: 100 }}
>
<div class="flex-1 overflow-y-auto px-8 py-6" bind:this={chatScrollRef}>
<div class="max-w-7xl mx-auto">
<div class="max-w-3xl mx-auto">
<ChatMessages scrollParent={chatScrollRef} />
</div>
</div>
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<div class="max-w-3xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
</div>
</div>
@@ -1730,7 +1655,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<!-- Panel Header -->
<div class="flex items-center gap-2 mb-4">
<div class="w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse"></div>
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<h3 class="text-sm text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
@@ -1776,28 +1701,28 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex justify-between items-start mb-2 pl-2">
<div class="flex items-center gap-2">
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
<span class="text-exo-light-gray font-mono text-xs tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400/80 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-exo-yellow text-sm font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
class="inline-flex items-center gap-1 text-[10px] text-white/60 hover:text-exo-yellow transition-colors mt-0.5"
href={`https://huggingface.co/${instanceModelId}`}
target="_blank"
rel="noreferrer noopener"
aria-label="View model on Hugging Face"
>
<span>Hugging Face</span>
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<svg class="w-3 h-3" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M14 3h7v7"/>
<path d="M10 14l11-11"/>
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
@@ -1808,84 +1733,68 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-white/60 text-xs font-mono">{instanceInfo.nodeNames.join(', ')}</div>
{/if}
{#if debugEnabled && instanceConnections.length > 0}
<div class="mt-2 space-y-1">
{#each instanceConnections as conn}
<div class="text-[11px] leading-snug font-mono text-white/70">
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
<div class="mt-1 space-y-0.5">
{#each instanceConnections as conn}
<div class="text-[10px] leading-snug font-mono text-white/70">
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
</div>
{/each}
</div>
{/if}
<!-- Download Progress -->
{#if downloadInfo.isDownloading && downloadInfo.progress}
<div class="mt-2 space-y-1">
<div class="flex justify-between text-sm font-mono">
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {downloadInfo.progress.percentage}%"
></div>
</div>
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
</div>
{/each}
</div>
{/if}
<!-- Download Progress -->
{#if downloadInfo.isDownloading && downloadInfo.progress}
<div class="mt-2 space-y-1">
<div class="flex justify-between text-xs font-mono">
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
</div>
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {downloadInfo.progress.percentage}%"
></div>
</div>
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
</div>
</div>
{#if downloadInfo.perNode.length > 0}
<div class="mt-2 space-y-2 max-h-48 overflow-y-auto pr-1">
{#each downloadInfo.perNode as nodeProg}
{@const nodePercent = Math.min(100, Math.max(0, nodeProg.progress.percentage))}
{@const isExpanded = instanceDownloadExpandedNodes.has(nodeProg.nodeId)}
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
<button
type="button"
class="w-full text-left space-y-1.5"
onclick={() => toggleInstanceDownloadDetails(nodeProg.nodeId)}
>
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
{#if downloadInfo.perNode.length > 0}
<div class="mt-2 space-y-1.5 max-h-48 overflow-y-auto pr-1">
{#each downloadInfo.perNode as nodeProg}
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
<span class="text-white/80 truncate pr-2">{nodeProg.nodeName}</span>
<span class="flex items-center gap-1 text-blue-300">
{nodePercent.toFixed(1)}%
<svg class="w-3 h-3 text-exo-light-gray" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
<path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
</svg>
</span>
<span class="text-blue-300">{Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%</span>
</div>
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mb-1.5">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {nodePercent.toFixed(1)}%"
class="absolute inset-y-0 left-0 bg-blue-500/80 transition-all duration-300"
style="width: {Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%"
></div>
</div>
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
<span>{formatBytes(nodeProg.progress.downloadedBytes)} / {formatBytes(nodeProg.progress.totalBytes)}</span>
<span>{formatSpeed(nodeProg.progress.speed)} • ETA {formatEta(nodeProg.progress.etaMs)}</span>
</div>
</button>
{#if isExpanded}
<div class="mt-2 space-y-1.5">
{#if nodeProg.progress.files.length === 0}
<div class="text-[11px] font-mono text-exo-light-gray/70">No file details reported.</div>
{:else}
{#each nodeProg.progress.files as f}
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
{@const isFileComplete = filePercent >= 100}
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
{#if nodeProg.progress.files.length > 0}
{@const inProgressFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) < 100)}
{@const completedFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) >= 100)}
{#if inProgressFiles.length > 0}
<div class="space-y-1">
{#each inProgressFiles as f}
<div class="text-[10px] font-mono text-exo-light-gray/80">
<div class="flex items-center justify-between">
<span class="truncate pr-2">{f.name}</span>
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
<span class="text-white/70">{Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
<div class="relative h-1 bg-exo-black/50 rounded-sm overflow-hidden mt-0.5">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
style="width: {filePercent.toFixed(1)}%"
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70"
style="width: {Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%"
></div>
</div>
<div class="flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5">
@@ -1894,17 +1803,27 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
</div>
{/each}
{/if}
</div>
</div>
{/if}
{#if completedFiles.length > 0}
<div class="pt-1 space-y-0.5">
{#each completedFiles as f}
<div class="text-[10px] font-mono text-exo-light-gray/70 flex items-center justify-between">
<span class="truncate pr-2">{f.name}</span>
<span class="text-white/60">100%</span>
</div>
{/each}
</div>
{/if}
{/if}
</div>
{/each}
</div>
</div>
{/each}
</div>
{/if}
<div class="text-sm text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
{:else}
<div class="text-sm {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
{/if}
<div class="text-xs text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
{:else}
<div class="text-xs {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
{/if}
</div>
</div>
</div>

View File

@@ -345,19 +345,13 @@
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
<div class="flex items-center justify-between gap-3">
<div class="min-w-0 space-y-0.5">
<div
class="text-xs font-mono text-white truncate"
title={model.prettyName ?? model.modelId}
>{model.prettyName ?? model.modelId}</div>
<div
class="text-[10px] text-exo-light-gray font-mono truncate"
title={model.modelId}
>{model.modelId}</div>
{#if model.status !== 'completed'}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
{/if}
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
<div class="text-[11px] text-exo-light-gray font-mono truncate">
{model.modelId}
</div>
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
</div>
<div class="flex items-center gap-2">
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
@@ -432,14 +426,14 @@
<style>
.downloads-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
}
@media (min-width: 1024px) {
.downloads-grid {
grid-template-columns: repeat(3, minmax(0, 1fr));
}
}
@media (min-width: 1600px) {
@media (min-width: 1440px) {
.downloads-grid {
grid-template-columns: repeat(4, minmax(0, 1fr));
}

View File

@@ -29,12 +29,10 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx>=0.30.1",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
]
[project.scripts]

View File

@@ -13,12 +13,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -27,13 +21,11 @@ from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -64,7 +56,7 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
HIDE_THINKING = False
def chunk_to_response(
@@ -169,9 +161,7 @@ class API:
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/v1/chat/completions")(self.chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -187,32 +177,17 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
raise HTTPException(
status_code=400,
detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
)
command = CreateInstance(
instance=instance,
)
command = CreateInstance(instance=payload.instance)
await self._send(command)
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
)
async def get_placement(
@@ -232,6 +207,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -287,6 +263,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -377,52 +354,32 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
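
The deleted _process_gpt_oss wrapped tokens from the harmony "analysis" channel in <think>…</think> markers, closing the tag on channel exit or at the finish chunk. The same two-state machine, sketched in TypeScript for illustration (types hypothetical):

// Sketch of the removed <think>-tagging state machine.
type Tok = { channel: string; delta: string; finished?: boolean };
function* tagThinking(tokens: Iterable<Tok>): Generator<string> {
  let thinking = false;
  for (const t of tokens) {
    if (t.channel === 'analysis' && !thinking) { thinking = true; yield '<think>'; }
    if (t.channel !== 'analysis' && thinking) { thinking = false; yield '</think>'; }
    if (t.delta) yield t.delta;
    if (t.finished) { if (thinking) yield '</think>'; return; }
  }
}
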
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
is_thinking = False
with recv as token_chunks:
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
async for chunk in token_chunks:
if HIDE_THINKING:
if chunk.text == "<think>":
is_thinking = True
if chunk.text == "</think>":
is_thinking = False
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
if not (is_thinking and HIDE_THINKING):
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -437,59 +394,6 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
@@ -497,12 +401,10 @@ class API:
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
) -> StreamingResponse:
"""Handle chat completions with proper streaming response."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -517,21 +419,17 @@ class API:
request_params=payload,
)
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
return StreamingResponse(
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
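
The endpoint now always streams text/event-stream frames of the form data: {json}\n\n, terminated by data: [DONE]. A minimal client-side sketch of draining such a stream, simplified from what the dashboard's reader loop does:

// Sketch: consume the SSE frames emitted by /v1/chat/completions.
async function drainChatStream(res: Response, onDelta: (s: string) => void) {
  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  let buf = '';
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    buf += decoder.decode(value, { stream: true });
    let idx: number;
    // Process only complete frames; keep any partial tail in the buffer.
    while ((idx = buf.indexOf('\n\n')) !== -1) {
      const frame = buf.slice(0, idx);
      buf = buf.slice(idx + 2);
      if (!frame.startsWith('data: ')) continue;
      const payload = frame.slice(6);
      if (payload === '[DONE]') return;
      try { onDelta(JSON.parse(payload).choices?.[0]?.delta?.content ?? ''); }
      catch { /* skip malformed frames */ }
    }
  }
}
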
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
return total_available
@@ -545,8 +443,6 @@ class API:
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
)
for card in MODEL_CARDS.values()
]
@@ -563,7 +459,7 @@ class API:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._applystate)
tg.start_soon(self._pause_on_new_election)
print_startup_banner(self.port)
await serve(
@@ -575,7 +471,7 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
async def _applystate(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:

View File

@@ -158,6 +158,7 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -200,9 +201,7 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:

View File

@@ -6,10 +6,11 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_mlx_ring_hosts_by_node,
get_mlx_jaccl_devices_matrix,
get_shard_assignments,
get_smallest_cycles,
)
@@ -19,9 +20,10 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.topology import NodeInfo
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -50,19 +52,16 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
cycles = topology.get_cycles() + [[node] for node in all_nodes]
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
candidate_cycles, node_profiles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
@@ -70,13 +69,15 @@ def place_instance(
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
if topology.get_subgraph_from_nodes(
[node.node_id for node in cycle]
).is_thunderbolt_cycle([node.node_id for node in cycle])
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -85,11 +86,7 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
(node.node_profile.memory.ram_available for node in cycle),
start=Memory(),
),
)
@@ -98,14 +95,16 @@ def place_instance(
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
[node.node_id for node in selected_cycle]
)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -113,33 +112,32 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
selected_cycle,
coordinator=selected_cycle[0].node_id,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
hosts_by_node = get_mlx_ring_hosts_by_node(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
)
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
hosts=[
Host(
ip=host.ip,
port=random_ephemeral_port(),
)
for host in hosts
],
)
return target_instances
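
Taken together, the placement hunks above narrow candidates in a fixed order: graph cycles plus single-node "cycles" are filtered by min_nodes, then by total available memory, then reduced to the smallest cycles, with all-Thunderbolt cycles and cycles containing a leaf node preferred, and ties broken by total available RAM. A minimal standalone sketch of that order, with plain dicts and callables standing in for exo's Topology and profile types (all names here are illustrative):

def select_cycle(
    cycles: list[list[str]],   # candidate cycles as lists of node ids
    ram: dict[str, int],       # available RAM per node, in bytes
    required: int,             # model storage size, in bytes
    min_nodes: int,
    is_tb_cycle,               # cycle -> bool: are all links Thunderbolt?
    has_leaf,                  # cycle -> bool: does it contain a leaf node?
) -> list[str]:
    candidates = [c for c in cycles if len(c) >= min_nodes]
    fitting = [c for c in candidates if sum(ram[n] for n in c) >= required]
    if not fitting:
        raise ValueError("No cycles found with sufficient memory")
    # prefer the smallest cycles: fewer hops, less communication overhead
    smallest = [c for c in fitting if len(c) == min(map(len, fitting))]
    tb_only = [c for c in smallest if is_tb_cycle(c)]
    if tb_only:
        smallest = tb_only
    leafy = [c for c in smallest if has_leaf(c)]
    # tie-break on total available memory
    return max(leafy or smallest, key=lambda c: sum(ram[n] for n in c))
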

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import TypeGuard, cast
from collections.abc import Generator, Mapping
from loguru import logger
from pydantic import BaseModel
@@ -9,7 +8,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -24,27 +23,32 @@ class NodeWithProfile(BaseModel):
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[list[NodeWithProfile]]:
filtered_cycles: list[list[NodeWithProfile]] = []
for cycle in cycles:
if not narrow_all_nodes(cycle):
if not all(node in node_profiles for node in cycle):
continue
total_mem = sum(
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[NodeInfo], cycle))
filtered_cycles.append(
[
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
for node in cycle
]
)
return filtered_cycles
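
filter_cycles_by_memory folds per-node ram_available with sum(..., start=Memory()); sum() only needs __add__ plus a zero element supplied by start, so it works for non-numeric types. A toy stand-in for exo's Memory type illustrating the pattern:

from dataclasses import dataclass

@dataclass(frozen=True)
class Memory:  # toy stand-in, not exo.shared.types.memory.Memory
    in_bytes: int = 0

    def __add__(self, other: "Memory") -> "Memory":
        return Memory(self.in_bytes + other.in_bytes)

    def __ge__(self, other: "Memory") -> bool:
        return self.in_bytes >= other.in_bytes

total = sum((Memory(512), Memory(256)), start=Memory())
assert total == Memory(768) and total >= Memory(700)
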
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
def get_smallest_cycles(
cycles: list[list[NodeWithProfile]],
) -> list[list[NodeWithProfile]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
@@ -135,11 +139,9 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeInfo],
selected_cycle: list[NodeWithProfile],
sharding: Sharding,
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
@@ -176,17 +178,16 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
for src, sink, connection in cycle_digraph.list_connections():
if not isinstance(connection, SocketConnection):
continue
if src == current_node and sink == next_node:
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
@@ -194,8 +195,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
return hosts
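
The walk above visits consecutive ring pairs and, for each hop (current, next), records the sink address of the socket edge current -> next. Reduced to a standalone sketch, with tuples standing in for Host and addresses mirroring the conftest fixtures used in the tests further down:

def ring_hop_hosts(
    cycle: list[str],
    sink_addr: dict[tuple[str, str], tuple[str, int]],  # (src, sink) -> (ip, port)
) -> list[tuple[str, int]]:
    return [
        sink_addr[(cycle[i], cycle[(i + 1) % len(cycle)])]
        for i in range(len(cycle))
    ]

assert ring_hop_hosts(
    ["a", "b", "c"],
    {
        ("a", "b"): ("169.254.0.2", 1234),
        ("b", "c"): ("169.254.0.3", 1234),
        ("c", "a"): ("169.254.0.4", 1234),
    },
) == [("169.254.0.2", 1234), ("169.254.0.3", 1234), ("169.254.0.4", 1234)]
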
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
def get_mlx_jaccl_devices_matrix(
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -204,6 +204,7 @@ def get_mlx_ibv_devices_matrix(
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
"""
selected_cycle = list(cycle_digraph.list_nodes())
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -214,191 +215,55 @@ def get_mlx_ibv_devices_matrix(
if i == j:
continue
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
break
else:
logger.warning(
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current ibv backend requires all-to-all rdma connections"
"Current jaccl backend requires all-to-all RDMA connections"
)
return matrix
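
For the three-node all-to-all mesh exercised in the tests below, the result is a 3x3 matrix with None on the diagonal and the source-side RDMA interface name of the i -> j edge everywhere else. A sketch with a plain dict of directed edges standing in for get_all_connections_between (names here are illustrative):

def jaccl_matrix(
    nodes: list[str],
    rdma_iface: dict[tuple[str, str], str],  # (src, sink) -> source-side iface
) -> list[list[str | None]]:
    n = len(nodes)
    matrix: list[list[str | None]] = [[None] * n for _ in range(n)]
    for i, a in enumerate(nodes):
        for j, b in enumerate(nodes):
            if i == j:
                continue  # diagonal stays None
            if (a, b) not in rdma_iface:
                raise ValueError(
                    "Current jaccl backend requires all-to-all RDMA connections"
                )
            matrix[i][j] = rdma_iface[(a, b)]
    return matrix

edges = {
    ("a", "b"): "rdma_en3", ("b", "a"): "rdma_en3",
    ("b", "c"): "rdma_en4", ("c", "b"): "rdma_en4",
    ("c", "a"): "rdma_en5", ("a", "c"): "rdma_en5",
}
assert jaccl_matrix(["a", "b", "c"], edges)[0][1] == "rdma_en3"
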
def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
node_i: NodeId,
node_j: NodeId,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
return None
def _find_ip_prioritised(
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
Priority order:
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
en0_ip = iface_map.get("en0")
if en0_ip:
return en0_ip
en1_ip = iface_map.get("en1")
if en1_ip:
return en1_ip
non_thunderbolt_ip = next(
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
)
if non_thunderbolt_ip:
return non_thunderbolt_ip
if ips:
return ips[0][0]
return None
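
The removed helper above encodes a fallback chain: prefer en0, then en1, then any non-Thunderbolt IP, then anything at all. Its core, as a standalone sketch with toy inputs:

def pick_ip(
    iface_map: dict[str, str],    # interface name -> ip on the peer
    ips: list[tuple[str, bool]],  # (ip, is_thunderbolt) candidates
) -> str | None:
    for preferred in ("en0", "en1"):
        if preferred in iface_map:
            return iface_map[preferred]
    non_tb = next((ip for ip, is_tb in ips if not is_tb), None)
    if non_tb:
        return non_tb
    return ips[0][0] if ips else None

assert pick_ip({"en1": "10.0.0.2"}, [("10.0.0.2", False)]) == "10.0.0.2"
assert pick_ip({}, [("169.254.0.9", True)]) == "169.254.0.9"
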
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
Each node gets a list where:
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
- Left/right neighbors: actual connection IPs
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
"""
world_size = len(selected_cycle)
if world_size == 0:
return {}
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
hosts_for_node: list[Host] = []
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
if idx not in {left_rank, right_rank}:
# Placeholder IP from RFC 5737 TEST-NET-2
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
)
hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))
hosts_by_node[node_id] = hosts_for_node
return hosts_by_node
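
For a four-node ring the per-node list produced above has exactly this shape: self binds 0.0.0.0, the two ring neighbours get real connection IPs, and everyone else gets the TEST-NET-2 placeholder. A standalone sketch, with tuples standing in for Host and illustrative IPs:

def ring_hosts(
    nodes: list[str],
    rank: int,
    ip_between: dict[tuple[str, str], str],  # (node, neighbour) -> ip
    port: int,
) -> list[tuple[str, int]]:
    world = len(nodes)
    left, right = (rank - 1) % world, (rank + 1) % world
    hosts: list[tuple[str, int]] = []
    for idx, other in enumerate(nodes):
        if idx == rank:
            hosts.append(("0.0.0.0", port))          # bind self
        elif idx in (left, right):
            hosts.append((ip_between[(nodes[rank], other)], port))
        else:
            hosts.append(("198.51.100.1", 0))        # RFC 5737 placeholder
    return hosts

assert ring_hosts(
    ["a", "b", "c", "d"], 0,
    {("a", "b"): "169.254.0.2", ("a", "d"): "169.254.0.4"},
    50000,
) == [
    ("0.0.0.0", 50000),
    ("169.254.0.2", 50000),
    ("198.51.100.1", 0),
    ("169.254.0.4", 50000),
]
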
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address
def get_mlx_jaccl_coordinators(
selected_cycle: list[NodeInfo],
coordinator: NodeId,
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
Select an IP address through which each node can reach the rank 0 node. Returns
an address in the format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
selected_cycle = list(cycle_digraph.list_nodes())
logger.info(f"Selecting coordinator: {coordinator}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
return "0.0.0.0"
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
return ip
logger.warning(
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
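
The returned mapping has one entry per node: the coordinator itself binds 0.0.0.0, every other node gets an IP through which it can reach the coordinator, and all entries share the same port. Its shape, as a toy sketch:

def coordinators(
    nodes: list[str],
    coordinator: str,
    port: int,
    ip_towards: dict[tuple[str, str], str],  # (node, coordinator) -> reachable ip
) -> dict[str, str]:
    def ip_for(n: str) -> str:
        return "0.0.0.0" if n == coordinator else ip_towards[(n, coordinator)]
    return {n: f"{ip_for(n)}:{port}" for n in nodes}

assert coordinators(
    ["a", "b", "c"], "a", 5000,
    {("b", "a"): "169.254.0.2", ("c", "a"): "169.254.0.5"},
) == {"a": "0.0.0.0:5000", "b": "169.254.0.2:5000", "c": "169.254.0.5:5000"}
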

View File

@@ -1,67 +1,36 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
return _create_connection
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
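
For orientation, this is how the two helpers above compose when wiring a test topology; the call pattern mirrors the tests below (Topology and NodeId are the exo types imported there, and the multiaddr encodes address and port as /ip4/<addr>/tcp/<port>):

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId

node_a, node_b = NodeId(), NodeId()
topology = Topology()
topology.add_node(node_a)
topology.add_node(node_b)
# each direction carries its own connection object
topology.add_connection(node_a, node_b, create_connection(1))
topology.add_connection(node_b, node_a, create_connection(2))
# RDMA edges live alongside socket edges between the same pair of nodes
topology.add_connection(node_a, node_b, create_rdma_connection(3))
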

View File

@@ -19,15 +19,13 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodePerformanceMeasured,
NodeGatheredInfo,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
MemoryUsage,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -83,21 +81,14 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodePerformanceMeasured(
NodeGatheredInfo(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
)
),
@@ -123,8 +114,6 @@ async def test_master():
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
@@ -163,40 +152,34 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
# Validate the shard assignments
expected_shard_assignments = ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()
)[0]
assert events[1].event.instance == MlxRingInstance(
instance_id=events[1].event.instance.instance_id,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
),
hosts=[],
)
assert created_instance.shard_assignments == expected_shard_assignments
# For single-node, hosts_by_node should have one entry with self-binding
assert len(created_instance.hosts_by_node) == 1
assert node_id in created_instance.hosts_by_node
assert len(created_instance.hosts_by_node[node_id]) == 1
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)

View File

@@ -1,5 +1,3 @@
from typing import Callable
import pytest
from loguru import logger
@@ -7,14 +5,20 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_connection,
create_node_profile,
create_rdma_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -26,11 +30,6 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -38,8 +37,7 @@ def instance() -> Instance:
shard_assignments=ShardAssignments(
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
),
hosts_by_node={},
ephemeral_port=50000,
hosts=[],
)
@@ -50,8 +48,6 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=10,
supports_tensor=True,
)
@@ -77,34 +73,33 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(2))
topology.add_connection(node_id_c, node_id_a, create_connection(3))
# act
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
# assert
assert len(placements) == 1
@@ -130,23 +125,20 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_exact_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -157,23 +149,20 @@ def test_get_instance_placements_one_node_exact_fit(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1001 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -184,25 +173,22 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_not_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {})
place_instance(cic, topology, {}, profiles)
def test_get_transition_events_no_change(instance: Instance):
@@ -247,179 +233,103 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
def test_placement_prioritizes_leaf_cycle_with_less_memory(
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
# arrange
topology = Topology()
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size = Memory.from_bytes(1000)
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# Daisy chain topology
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_a, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(1))
topology.add_connection(node_id_c, node_id_b, create_connection(1))
topology.add_connection(node_id_c, node_id_d, create_connection(1))
topology.add_connection(node_id_d, node_id_c, create_connection(1))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
logger.info(list(topology.list_connections()))
cic = place_instance_command(
model_meta=model_meta,
)
# Act
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, profiles)
# Assert: D-E-F cycle should be selected as it has more total memory
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
instance = list(placements.values())[0]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(node_id_c, node_id_d)
)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="192.168.1.100",
)
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(node_a, node_b, create_rdma_connection(3))
topology.add_connection(node_b, node_c, create_rdma_connection(4))
topology.add_connection(node_c, node_a, create_rdma_connection(5))
topology.add_connection(node_b, node_a, create_rdma_connection(3))
topology.add_connection(node_c, node_b, create_rdma_connection(4))
topology.add_connection(node_a, node_c, create_rdma_connection(5))
topology.add_connection(node_a, node_b, ethernet_conn)
topology.add_connection(node_b, node_c, ethernet_conn)
topology.add_connection(node_c, node_a, ethernet_conn)
topology.add_connection(node_a, node_c, ethernet_conn)
topology.add_connection(node_b, node_a, ethernet_conn)
topology.add_connection(node_c, node_b, ethernet_conn)
cic = PlaceInstance(
sharding=Sharding.Tensor,
@@ -429,7 +339,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -437,10 +347,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.jaccl_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
matrix = instance.jaccl_devices
assert len(matrix) == 3
for i in range(3):
@@ -449,15 +359,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3

View File

@@ -1,56 +1,48 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_connection, create_node_profile
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology.add_node(node1)
topology.add_node(node2)
topology.add_node(node1_id)
topology.add_node(node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
connection1 = create_connection(1)
connection2 = create_connection(2)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
# assert
assert len(filtered_cycles) == 1
@@ -58,64 +50,65 @@ def test_filter_cycles_by_memory(
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_insufficient_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology.add_node(node1)
topology.add_node(node2)
topology.add_node(node1_id)
topology.add_node(node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
connection1 = create_connection(1)
connection2 = create_connection(2)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), Memory.from_kb(2001)
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_multiple_cycles_by_memory():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
# assert
assert len(filtered_cycles) == 1
@@ -127,31 +120,38 @@ def test_filter_multiple_cycles_by_memory(
}
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_get_smallest_cycles():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
# act
smallest_cycles = get_smallest_cycles(topology.get_cycles())
smallest_cycles = get_smallest_cycles(cycles)
# assert
assert len(smallest_cycles) == 1
@@ -168,9 +168,6 @@ def test_get_smallest_cycles(
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -179,29 +176,37 @@ def test_get_shard_assignments(
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_c_id, create_connection(2))
topology.add_connection(node_c_id, node_a_id, create_connection(3))
topology.add_connection(node_b_id, node_a_id, create_connection(4))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
selected_cycle = cycles[0]
# act
@@ -230,28 +235,21 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_hosts_from_subgraph():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -259,108 +257,47 @@ def test_get_hosts_from_subgraph(
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
Host(ip=("169.254.0.2"), port=1234),
Host(ip=("169.254.0.3"), port=1234),
Host(ip=("169.254.0.4"), port=1234),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_mlx_jaccl_coordinators():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(1)
conn_b_a = create_connection(2)
conn_b_c = create_connection(3)
conn_c_b = create_connection(4)
conn_c_a = create_connection(5)
conn_a_c = create_connection(6)
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_b)
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
topology.add_connection(node_a_id, node_b_id, conn_a_b)
topology.add_connection(node_b_id, node_a_id, conn_b_a)
topology.add_connection(node_b_id, node_c_id, conn_b_c)
topology.add_connection(node_c_id, node_b_id, conn_c_b)
topology.add_connection(node_c_id, node_a_id, conn_c_a)
topology.add_connection(node_a_id, node_c_id, conn_a_c)
# act
coordinators = get_mlx_jaccl_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
node_a_id, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -389,11 +326,11 @@ def test_get_mlx_jaccl_coordinators(
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
"node_b should use the IP from conn_b_a"
)
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
"node_c should use the IP from conn_c_a"
)

View File

@@ -1,13 +1,14 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
from exo.shared.types.topology import SocketConnection
@pytest.fixture
@@ -16,20 +17,15 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
def connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile.from_bytes(
memory_profile = MemoryUsage.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -43,162 +39,85 @@ def node_profile() -> NodePerformanceProfile:
)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
def test_add_node(topology: Topology):
# arrange
node_id = NodeId()
# act
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
topology.add_node(node_id)
# assert
data = topology.get_node_profile(node_id)
assert data == node_profile
assert topology.node_is_leaf(node_id)
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_add_connection(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
# act
data = topology.get_connection_profile(connection)
data = list(conn for _, _, conn in topology.list_connections())
# assert
assert data == connection.connection_profile
assert data == [connection]
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, connection: SocketConnection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
# act
topology.remove_connection(connection)
topology.remove_connection(node_a, node_b, connection)
# assert
assert topology.get_connection_profile(connection) is None
assert list(topology.get_all_connections_between(node_a, node_b)) == []
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
topology.remove_node(connection.local_node_id)
topology.remove_node(node_b)
# assert
assert topology.get_node_profile(connection.local_node_id) is None
assert list(topology.out_edges(node_a)) == []
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_list_nodes(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}
assert all(isinstance(node, NodeId) for node in nodes)
assert {node for node in nodes} == {node_a, node_b}

View File

@@ -11,10 +11,8 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeGatheredInfo,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -27,13 +25,23 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import RDMAConnection
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacTBConnections,
MacTBIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
@@ -47,16 +55,12 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -188,7 +192,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.copy(state.topology)
topology = copy.deepcopy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -196,8 +200,12 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -205,103 +213,69 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
# TODO: should be broken up into individual events instead of this monster
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
# TODO: makes me slightly sad
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacTBIdentifiers():
profile.tb_interfaces = info.idents
case MacTBConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
(
conn_map[tb_conn.sink_uuid][0],
RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
return state.model_copy(
update={
"node_profiles": new_profiles,
"topology": topology,
"last_seen": last_seen,
"topology": topology,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.add_connection(event.source, event.sink, event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.sink, event.source, event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})
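The same reducer shape repeats through this file: deep-copy the mutable topology graph, mutate the copy, and hand back a fresh `State` via `model_copy`. A minimal self-contained sketch of that pattern, with simplified stand-ins rather than the real exo `State`/`Topology` types:

```python
# Illustrative stand-ins only; the real types live in exo.shared.types.
import copy
from dataclasses import dataclass, field

@dataclass
class Topology:
    nodes: set[str] = field(default_factory=set)

    def remove_node(self, node_id: str) -> None:
        self.nodes.discard(node_id)

@dataclass(frozen=True)
class State:
    topology: Topology = field(default_factory=Topology)
    last_seen: dict[str, str] = field(default_factory=dict)

def apply_node_timed_out(node_id: str, state: State) -> State:
    # Deep-copy the mutable graph so older State snapshots stay valid,
    # then rebuild the immutable State around the filtered mappings.
    topology = copy.deepcopy(state.topology)
    topology.remove_node(node_id)
    last_seen = {k: v for k, v in state.last_seen.items() if k != node_id}
    return State(topology=topology, last_seen=last_seen)
```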

View File

@@ -38,6 +38,7 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

View File

@@ -24,6 +24,8 @@ class _InterceptHandler(logging.Handler):
except ValueError:
level = record.levelno
return
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())

View File

@@ -51,8 +51,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
@@ -66,8 +64,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
@@ -139,8 +135,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
@@ -154,8 +148,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
@@ -170,38 +162,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
@@ -215,8 +175,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
@@ -231,8 +189,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
@@ -246,8 +202,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
@@ -261,8 +215,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
@@ -277,8 +229,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
@@ -292,8 +242,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
@@ -307,8 +255,20 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# phi-3
"phi-3-mini": ModelCard(
short_id="phi-3-mini",
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
name="Phi 3 Mini 128k (4-bit)",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
pretty_name="Phi 3 Mini 128k (4-bit)",
storage_size=Memory.from_mb(2099),
n_layers=32,
),
),
# qwen3
@@ -323,8 +283,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
@@ -338,8 +296,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
@@ -353,8 +309,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
@@ -368,68 +322,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
@@ -443,8 +335,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
@@ -458,8 +348,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
@@ -473,8 +361,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
@@ -488,84 +374,77 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
# granite
"granite-3.3-2b": ModelCard(
short_id="granite-3.3-2b",
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
name="Granite 3.3 2B (FP16)",
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
pretty_name="Granite 3.3 2B (FP16)",
storage_size=Memory.from_mb(4951),
n_layers=40,
),
),
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# Needs to be quantized g32 or g16.
"glm-4.5-air-8bit": ModelCard(
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# "granite-3.3-8b": ModelCard(
# short_id="granite-3.3-8b",
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# name="Granite 3.3 8B",
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# pretty_name="Granite 3.3 8B",
# storage_size=Memory.from_kb(15958720),
# n_layers=40,
# ),
# ),
# smol-lm
# "smol-lm-135m": ModelCard(
# short_id="smol-lm-135m",
# model_id="mlx-community/SmolLM-135M-4bit",
# name="Smol LM 135M",
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
# pretty_name="Smol LM 135M",
# storage_size=Memory.from_kb(73940),
# n_layers=30,
# ),
# ),
# gpt-oss
# "gpt-oss-120b-MXFP4-Q8": ModelCard(
# short_id="gpt-oss-120b-MXFP4-Q8",
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# storage_size=Memory.from_kb(68_996_301),
# n_layers=36,
# hidden_size=2880,
# supports_tensor=True,
# ),
# ),
# "gpt-oss-20b-4bit": ModelCard(
# short_id="gpt-oss-20b-4bit",
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# storage_size=Memory.from_kb(11_744_051),
# n_layers=24,
# hidden_size=2880,
# supports_tensor=True,
# ),
# ),
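With `hidden_size` and `supports_tensor` dropped from `ModelMetadata`, a new card needs only four metadata fields. A hedged sketch of registering one — the model, repo, and sizes are fictional, and the `ModelCard` import location is assumed from this file's context:

```python
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard  # ModelCard path assumed
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata

MODEL_CARDS["example-7b"] = ModelCard(
    short_id="example-7b",
    model_id=ModelId("mlx-community/Example-7B-4bit"),  # fictional repo
    name="Example 7B (4-bit)",
    description="""Illustrative entry only; not a real model.""",
    tags=[],
    metadata=ModelMetadata(
        model_id=ModelId("mlx-community/Example-7B-4bit"),
        pretty_name="Example 7B (4-bit)",
        storage_size=Memory.from_mb(4000),
        n_layers=32,
    ),
)
```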

View File

@@ -6,7 +6,6 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
@@ -26,7 +25,6 @@ class ConfigData(BaseModel):
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
@@ -108,19 +106,10 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
pretty_name=model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -36,8 +36,6 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,
supports_tensor=True,
),
device_rank=device_rank,
world_size=world_size,

View File

@@ -39,7 +39,4 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
from exo.shared.types.topology import SocketConnection
def test_state_serialization_roundtrip() -> None:
@@ -11,14 +11,12 @@ def test_state_serialization_roundtrip() -> None:
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
connection = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
state.topology.add_connection(connection)
state.topology.add_connection(node_a, node_b, connection)
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)

View File

@@ -1,203 +1,219 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
class TopologySnapshot(BaseModel):
nodes: list[NodeInfo]
connections: list[Connection]
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
model_config = ConfigDict(frozen=True, extra="forbid")
@dataclass
class Topology:
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
# the _graph can be used as an int -> NodeId map.
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
nodes=list(self.list_nodes()), connections=self.map_connections()
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node in snapshot.nodes:
for node_id in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node)
topology.add_node(node_id)
for connection in snapshot.connections:
topology.add_connection(connection)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for conn in snapshot.connections[source][sink]:
topology.add_connection(source, sink, conn)
return topology
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
return
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
]
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
def out_edges(
self, node_id: NodeId
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
if node_id not in self._vertex_indices:
return []
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
]
return (
(self._graph[nid], conn)
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
)
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
return node_id in self._vertex_indices
def add_connection(
self,
connection: Connection,
source: NodeId,
sink: NodeId,
connection: SocketConnection | RDMAConnection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
if connection in self.get_all_connections_between(source, sink):
return
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
_ = self._graph.add_edge(src_id, sink_id, connection)
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
try:
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def list_connections(
self,
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
return (
(
self._graph[src_id],
self._graph[sink_id],
connection,
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._node_id_to_rx_id_map:
if node_id not in self._vertex_indices:
return
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
rx_idx = self._vertex_indices[node_id]
self._graph.remove_node(rx_idx)
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
del self._vertex_indices[node_id]
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
def replace_all_out_tb_connections(
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)
def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
if source not in self._vertex_indices or sink not in self._vertex_indices:
return
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[source], self._vertex_indices[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[list[NodeInfo]]:
def get_cycles(self) -> list[list[NodeId]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_cycles_tb(self) -> list[list[NodeInfo]]:
def get_cycles_tb(self) -> list[list[NodeId]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
tb_graph.add_edge(u, v, conn)
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
for source, sink, connection in self.list_connections():
if source in node_ids and sink in node_ids:
topology.add_connection(source, sink, connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
node_idxs = list(cycle)
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:
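The reworked `Topology` stores bare `NodeId`s as vertices and typed connection values as edges, with parallel edges between the same pair allowed. A usage sketch under the types visible in this diff (addresses and interface names are placeholders):

```python
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import RDMAConnection, SocketConnection

topo = Topology()
a, b = NodeId("node-a"), NodeId("node-b")

# add_connection creates missing vertices; a socket edge and an RDMA edge
# between the same pair coexist as parallel edges.
topo.add_connection(a, b, SocketConnection(
    sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/10001")))
topo.add_connection(a, b, RDMAConnection(
    source_rdma_iface="rdma_en5", sink_rdma_iface="rdma_en6"))

# Both edges survive the snapshot round-trip.
restored = Topology.from_snapshot(topo.to_snapshot())
assert len(list(restored.get_all_connections_between(a, b))) == 2
```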

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -174,7 +174,6 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.topology import SocketConnection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,25 +76,15 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
class NodePerformanceMeasured(BaseEvent):
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
node_id: NodeId
when: str # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
info: GatheredInfo # NB: this model is UNTAGGED!!! be warned for ser/de errors.
class NodeDownloadProgress(BaseEvent):
@@ -107,11 +97,15 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
edge: Connection
source: NodeId
sink: NodeId
edge: SocketConnection
class TopologyEdgeDeleted(BaseEvent):
edge: Connection
source: NodeId
sink: NodeId
edge: SocketConnection
Event = (
@@ -125,10 +119,8 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeGatheredInfo
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated

View File

@@ -14,5 +14,3 @@ class ModelMetadata(CamelCaseModel):
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -1,10 +1,11 @@
import re
from typing import ClassVar
from pydantic import BaseModel, computed_field, field_validator
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,12 +1,14 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryPerformanceProfile(CamelCaseModel):
class MemoryUsage(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -44,7 +46,6 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -53,15 +54,16 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[TBIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float
pass

View File

@@ -40,10 +40,6 @@ class LoadModel(BaseTask): # emitted by Worker
pass
class ConnectToGroup(BaseTask): # emitted by Worker
pass
class StartWarmup(BaseTask): # emitted by Worker
pass
@@ -61,11 +57,5 @@ class Shutdown(BaseTask): # emitted by Worker
Task = (
CreateRunner
| DownloadModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
)

View File

@@ -0,0 +1,64 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class TBConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class TBIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
# Intentionally minimal, only collecting data we care about - there's a lot more
class TBReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str
class TBConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None
class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None
device_name_key: str
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
if self.domain_uuid_key is None:
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
iface = f"rdma_{ifaces[tag]}"
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
def conn(self) -> TBConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
)
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
class TBConnectivity(BaseModel):
SPThunderboltDataType: list[TBConnectivityData]
@classmethod
async def gather(cls) -> list[TBConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType
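A sketch of driving these models by hand, mirroring the system_profiler polling loop in the `InfoGatherer` later in this diff. The receptacle-to-device mapping normally comes from `networksetup -listallhardwareports`; the one below is a placeholder and will `KeyError` for any other receptacle:

```python
import anyio

async def dump_tb_state() -> None:
    data = await TBConnectivity.gather()
    if data is None:  # not macOS, or system_profiler failed
        return
    iface_map = {"Thunderbolt 1": "en5"}  # placeholder mapping
    idents = [it for d in data if (it := d.ident(iface_map)) is not None]
    conns = [it for d in data if (it := d.conn()) is not None]
    print(idents, conns)

anyio.run(dump_tb_state)
```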

View File

@@ -1,37 +1,32 @@
from exo.shared.types.common import NodeId
from enum import Enum
from loguru import logger
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
def is_thunderbolt(self) -> bool:
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
logger.warning("duh")
return True
# TODO
class LinkType(str, Enum):
Thunderbolt = "Thunderbolt"
Ethernet = "Ethernet"
WiFi = "WiFi"
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")

View File

@@ -25,12 +25,11 @@ class BaseInstance(TaggedModel):
class MlxRingInstance(BaseInstance):
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int
hosts: list[Host]
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]

View File

@@ -1,43 +0,0 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)

View File

@@ -21,15 +21,7 @@ class BaseRunnerStatus(TaggedModel):
return isinstance(self, RunnerRunning)
class RunnerIdle(BaseRunnerStatus):
pass
class RunnerConnecting(BaseRunnerStatus):
pass
class RunnerConnected(BaseRunnerStatus):
class RunnerWaitingForModel(BaseRunnerStatus):
pass
@@ -62,9 +54,7 @@ class RunnerFailed(BaseRunnerStatus):
RunnerStatus = (
RunnerIdle
| RunnerConnecting
| RunnerConnected
RunnerWaitingForModel
| RunnerLoading
| RunnerLoaded
| RunnerWarmingUp

View File

@@ -0,0 +1,231 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacTBIdentifiers(TaggedModel):
idents: Sequence[TBIdentifier]
class MacTBConnections(TaggedModel):
conns: Sequence[TBConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
# TODO
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacTBIdentifiers
| MacTBConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
if IS_DARWIN:
tg.start_soon(self._monitor_system_profiler)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await TBConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacTBIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacTBConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)
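A hedged sketch of running the gatherer in isolation: the real `Sender` comes from `exo.utils.channels`, but only its `.send()` is exercised here, so a print-only stand-in suffices (illustrative wiring, not the production setup):

```python
import anyio

class PrintSender:
    """Stand-in for exo.utils.channels.Sender; only .send() is needed here."""

    async def send(self, item: GatheredInfo) -> None:
        print(type(item).__name__)

async def main() -> None:
    gatherer = InfoGatherer(info_sender=PrintSender())  # type: ignore[arg-type]
    with anyio.move_on_after(5):  # sample for five seconds, then cancel
        await gatherer.run()

anyio.run(main)
```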

View File

@@ -0,0 +1,70 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
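For reference, `from_raw_json` takes one JSON object per line from `macmon pipe`. A synthetic payload shaped like `RawMacmonMetrics` above — all values are made up, and the `.in_bytes` accessor is the one used elsewhere in this diff:

```python
sample = (
    '{"timestamp": "2025-12-27T12:00:00Z",'
    ' "temp": {"cpu_temp_avg": 45.0, "gpu_temp_avg": 42.5},'
    ' "memory": {"ram_total": 68719476736, "ram_usage": 30064771072,'
    ' "swap_total": 0, "swap_usage": 0},'
    ' "ecpu_usage": [1020, 12.5], "pcpu_usage": [3204, 55.0],'
    ' "gpu_usage": [1398, 30.0], "all_power": 20.0, "ane_power": 0.0,'
    ' "cpu_power": 8.0, "gpu_power": 9.0, "gpu_ram_power": 1.0,'
    ' "ram_power": 2.0, "sys_power": 25.0}'
)
metrics = MacmonMetrics.from_raw_json(sample)
assert metrics.memory.ram_total.in_bytes == 68719476736
assert metrics.system_profile.gpu_usage == 30.0
```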

View File

@@ -0,0 +1,56 @@
import socket
from collections.abc import Mapping
from ipaddress import ip_address
from anyio import create_task_group, to_thread
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
# TODO: ref. api port
async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
return
finally:
sock.close()
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
async def check_reachable(
our_node_id: NodeId,
topology: Topology,
profiles: Mapping[NodeId, NodePerformanceProfile],
) -> Mapping[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
our_profile = profiles.get(our_node_id, None)
if our_profile is None:
return {}
our_interfaces = our_profile.network_interfaces
async with create_task_group() as tg:
for node_id in topology.list_nodes():
if node_id not in profiles or node_id == our_node_id:
continue
for iface in profiles[node_id].network_interfaces:
if ip_address(iface.ip_address).is_loopback:
# Definitely a loopback address
continue
if iface in our_interfaces:
# Skip duplicates with our own interfaces
continue
tg.start_soon(check_reachability, iface.ip_address, node_id, reachable)
return reachable
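A small sketch of invoking `check_reachable` (in practice the topology and profiles come from master state); with no profile recorded for our own node, it bails out early with an empty mapping:

```python
import anyio

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId

async def main() -> None:
    topology = Topology()
    topology.add_node(NodeId("node-a"))
    topology.add_node(NodeId("node-b"))
    # No profile for node-a, so check_reachable returns {} immediately.
    print(await check_reachable(NodeId("node-a"), topology, profiles={}))

anyio.run(main)
```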

View File

@@ -19,11 +19,20 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):
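`FrozenModel` is what makes the new connection types hashable and comparable, which in turn lets `Topology.add_connection` dedupe edges by equality. A quick illustration with a hypothetical model:

```python
from exo.utils.pydantic_ext import FrozenModel

class Endpoint(FrozenModel):  # hypothetical model, for illustration only
    address: str

e = Endpoint(address="10.0.0.2")
assert e == Endpoint(address="10.0.0.2")
assert hash(e) == hash(Endpoint(address="10.0.0.2"))
try:
    e.address = "10.0.0.3"  # frozen: pydantic raises a ValidationError
except Exception as exc:
    print(type(exc).__name__)
```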

View File

@@ -28,9 +28,8 @@ def bar(send: MpSender[str]):
send.close()
# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_setup():
async def test_channel_ipc():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -95,15 +95,7 @@ def extract_layer_num(tensor_name: str) -> int | None:
def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list[str]:
default_patterns = set(
[
"*.json",
"*.py",
"tokenizer.model",
"tiktoken.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt", "*.jinja"]
)
shard_specific_patterns: set[str] = set()
if weight_map:

View File

@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -13,7 +12,7 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
@abstractmethod
async def ensure_shard(
@@ -44,7 +43,34 @@ class ShardDownloader(ABC):
Yields:
tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.
"""
yield (Path("/tmp/noop_shard"), NOOP_DOWNLOAD_PROGRESS)
yield (
Path("/tmp/noop_shard"),
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
)
@abstractmethod
async def get_shard_download_status_for_shard(
@@ -68,41 +94,46 @@ class NoopShardDownloader(ShardDownloader):
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
yield (
Path("/tmp/noop_shard"),
NOOP_DOWNLOAD_PROGRESS,
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
dp = copy(NOOP_DOWNLOAD_PROGRESS)
dp.shard = shard
return dp
NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)
return RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=shard,
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)

View File

@@ -10,6 +10,7 @@ KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
TEMPERATURE: float = 1.0
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -13,6 +13,7 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
TEMPERATURE,
TRUST_REMOTE_CODE,
)
@@ -20,8 +21,6 @@ try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
import contextlib
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load_model
@@ -49,7 +48,6 @@ from exo.worker.engines.mlx.auto_parallel import (
)
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
# Needed for 8 bit model
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
@@ -69,7 +67,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
def mx_barrier(group: Group | None = None):
def mx_barrier(group: mx.distributed.Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
@@ -79,7 +77,7 @@ def mx_barrier(group: Group | None = None):
)
def broadcast_from_zero(value: int, group: Group | None = None):
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
if group is None:
return value
@@ -101,97 +99,85 @@ class HostList(RootModel[list[str]]):
def mlx_distributed_init(
bound_instance: BoundInstance,
) -> Group:
) -> mx.distributed.Group:
"""
Initialize MLX distributed.
Initialize the MLX distributed group.
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
coordination_file = None
try:
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
hostfile = f"./hosts_{rank}.json"
hosts_json = HostList.from_hosts(hosts).model_dump_json()
with open(coordination_file, "w") as f:
_ = f.write(hosts_json)
with open(hostfile, "w") as f:
_ = f.write(hosts_json)
logger.info(
f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
)
logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
os.environ["MLX_HOSTFILE"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
os.environ["MLX_HOSTFILE"] = hostfile
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
ibv_devices_json = json.dumps(ibv_devices)
case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
jaccl_devices_json = json.dumps(jaccl_devices)
with open(coordination_file, "w") as f:
_ = f.write(ibv_devices_json)
with open(devices_file, "w") as f:
_ = f.write(jaccl_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} MLX_JACCL_DEVICES: {jaccl_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
return group
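For the ring backend this is all driven through environment variables pointing at a JSON hostfile, as the hunk above shows. A stripped-down sketch of that bootstrap; the hostfile name and hosts here are illustrative, the env var names are the ones used above:

import json
import os

import mlx.core as mx

def init_ring(rank: int, hosts: list[str]) -> mx.distributed.Group:
    # e.g. hosts = ["10.0.0.1:50000", "10.0.0.2:50000"], one entry per rank
    hostfile = f"./hosts_{rank}.json"
    with open(hostfile, "w") as f:
        json.dump(hosts, f)
    os.environ["MLX_HOSTFILE"] = hostfile
    os.environ["MLX_RANK"] = str(rank)
    return mx.distributed.init(backend="ring", strict=True)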
def initialize_mlx(
bound_instance: BoundInstance,
) -> Group:
# should we unseed it?
# TODO: pass in seed from params
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
"""
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
"""
mx.random.seed(42)
assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
"Tried to initialize mlx for a single node instance"
)
return mlx_distributed_init(bound_instance)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
# TODO: pass temperature
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
logger.info("Created a sampler")
if group is None:
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
group = mlx_distributed_init(bound_instance)
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
@@ -201,12 +187,14 @@ def load_mlx_items(
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
logger.debug(model)
return cast(Model, model), tokenizer, sampler
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
group: mx.distributed.Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)

View File

@@ -16,15 +16,13 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -33,7 +31,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import Connection
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -44,14 +42,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -85,7 +83,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
self._tg: TaskGroup = create_task_group()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -97,37 +95,13 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -140,6 +114,17 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
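_forward_info is a plain fan-in: one receiver drains the InfoGatherer channel and wraps each item in a NodeGatheredInfo event. exo's channel() helper is assumed here to wrap anyio memory object streams; reduced to those primitives, the shape is:

import anyio

async def forwarder() -> None:
    send, recv = anyio.create_memory_object_stream[str](max_buffer_size=64)

    async def produce() -> None:
        async with send:
            for info in ("memory", "system", "network"):
                await send.send(info)

    async def forward() -> None:
        async with recv:
            async for info in recv:
                # The real loop wraps each item in a NodeGatheredInfo event.
                print(f"forwarding {info!r}")

    async with anyio.create_task_group() as tg:
        tg.start_soon(produce)
        tg.start_soon(forward)

anyio.run(forwarder)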
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -159,7 +144,6 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -228,7 +212,7 @@ class Worker:
)
)
else:
await self.event_sender.send(
self.event_sender.send_nowait(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
@@ -248,8 +232,7 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
if self._tg:
self._tg.cancel_scope.cancel()
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -266,24 +249,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
),
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
),
)
async def _nack_request(self, since_idx: int) -> None:
@@ -332,7 +315,6 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -391,7 +373,6 @@ class Worker:
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -414,33 +395,35 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology, self.node_id)
conns = await check_reachable(
self.node_id, self.state.topology, self.state.node_profiles
)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
edge = SocketConnection(
# nonsense multiaddr
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
await self.event_sender.send(
TopologyEdgeCreated(
source=self.node_id, sink=nid, edge=edge
)
)
for nid, conn in self.state.topology.out_edges(self.node_id):
if (
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
if not isinstance(conn, SocketConnection):
continue
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
nid, set()
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await self.event_sender.send(
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
)
await anyio.sleep(10)
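The probe loop builds one multiaddr per discovered IP, branching on IPv4 vs IPv6 with the "." check above. Just that branch, as a sketch; Multiaddr is assumed to be a thin wrapper around the address string, and 52415 is exo's API port:

def to_multiaddr(ip: str, port: int = 52415) -> str:
    # IPv4 addresses contain dots; anything else is treated as IPv6.
    return f"/ip4/{ip}/tcp/{port}" if "." in ip else f"/ip6/{ip}/tcp/{port}"

assert to_multiaddr("192.168.1.7") == "/ip4/192.168.1.7/tcp/52415"
assert to_multiaddr("fe80::1") == "/ip6/fe80::1/tcp/52415"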

View File

@@ -5,7 +5,6 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
LoadModel,
@@ -15,23 +14,17 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerId,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
@@ -55,7 +48,6 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
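plan() resolves to the first non-None proposal in a fixed priority order, which the `or` chain above encodes. The same first-match pattern, with hypothetical checks standing in for the planners:

from typing import Callable, Optional

def plan_sketch(checks: list[Callable[[], Optional[str]]]) -> Optional[str]:
    # Return the first task any check proposes, in priority order.
    for check in checks:
        task = check()
        if task is not None:
            return task
    return None

assert plan_sketch([lambda: None, lambda: "DownloadModel"]) == "DownloadModel"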
@@ -114,11 +106,9 @@ def _model_needs_download(
download_status: Mapping[ShardMetadata, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
if isinstance(runner.status, RunnerIdle) and (
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
)
if (
isinstance(runner.status, RunnerWaitingForModel)
and runner.bound_instance.bound_shard not in download_status
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
@@ -127,54 +117,14 @@ def _model_needs_download(
)
def _init_distributed_backend(
""" --- TODO!
def _init_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
):
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
if is_single_node_instance:
continue
runner_is_idle = isinstance(runner.status, RunnerIdle)
all_runners_connecting = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerConnecting, RunnerIdle),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if not (runner_is_idle and all_runners_connecting):
continue
runner_id = runner.bound_instance.bound_runner_id
shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
world_size = shard.world_size
assert device_rank < world_size
assert device_rank >= 0
accepting_ranks = device_rank < world_size - 1
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
)
if not (accepting_ranks or connecting_rank_ready):
continue
return ConnectToGroup(instance_id=instance.instance_id)
return None
) -> LoadModel | None:
for runner in runners.values():
    pass
"""
def _load_model(
@@ -186,33 +136,31 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
all_local_downloads_complete = all(
all_downloads_complete_local = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
for nid, rid in shard_assignments.node_to_runner.items()
)
if not all_local_downloads_complete:
continue
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
is_runner_waiting = isinstance(runner.status, RunnerConnected)
all_ready_for_model = all(
all_runners_expecting_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
all_runners.get(global_runner_id),
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if is_runner_waiting and all_ready_for_model:
if (
all_downloads_complete_local
and runner_is_waiting
and all_runners_expecting_model
):
return LoadModel(instance_id=instance.instance_id)
return None
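The gate above only emits LoadModel once every node's assigned shard appears as DownloadCompleted in the global status. With plain dicts standing in for the typed state (all names here illustrative):

def all_shards_downloaded(
    node_to_shard: dict[str, str],
    completed: dict[str, set[str]],
) -> bool:
    # True only when each node reports its own shard as completed.
    return all(
        shard in completed.get(node, set())
        for node, shard in node_to_shard.items()
    )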
@@ -235,9 +183,8 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# TODO: Ensure these align with MLX distributed's expectations.
# Rank < n-1
accepting_ranks_ready = device_rank < world_size - 1 and all(
# Rank != n-1
accepting_ranks_ready = device_rank != world_size - 1 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
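The accepting/connecting split works in two halves: every rank except the last may warm up once all runners are Loaded or WarmingUp, while the last rank waits until all the others are already WarmingUp (the rule the warmup tests below exercise). A sketch of that predicate, with statuses reduced to strings:

def may_warm_up(rank: int, world_size: int, statuses: dict[int, str]) -> bool:
    if rank != world_size - 1:
        return all(s in ("loaded", "warming") for s in statuses.values())
    return all(s == "warming" for r, s in statuses.items() if r != rank)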
@@ -274,8 +221,6 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# TODO: Check ordering aligns with MLX distributed's expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard

View File

@@ -2,13 +2,16 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import MpReceiver, MpSender
logger: "loguru.Logger" = loguru.logger
logger: "loguru.Logger"
if os.getenv("EXO_TESTS") == "1":
logger = loguru.logger
def entrypoint(
@@ -19,7 +22,7 @@ def entrypoint(
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
and len(bound_instance.instance.jaccl_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
@@ -27,23 +30,6 @@ def entrypoint(
logger = _logger
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
)
event_sender.send(
RunnerStatusUpdated(
runner_id=bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=str(e)),
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")
main(bound_instance, event_sender, task_receiver)

View File

@@ -11,7 +11,6 @@ from exo.shared.types.events import (
)
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
@@ -23,23 +22,20 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@@ -67,10 +63,9 @@ def main(
model = None
tokenizer = None
sampler = None
group = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
current_status: RunnerStatus = RunnerWaitingForModel()
logger.info("runner waiting for model")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
@@ -83,26 +78,9 @@ def main(
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
case LoadModel() if isinstance(
current_status, (RunnerWaitingForModel, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected)
and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
@@ -111,12 +89,15 @@ def main(
)
)
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
model, tokenizer, sampler = initialize_mlx(bound_instance)
current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
@@ -142,6 +123,11 @@ def main(
)
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
@@ -186,6 +172,11 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case Shutdown():
logger.info("runner shutting down")
event_sender.send(
@@ -195,19 +186,12 @@ def main(
)
break
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
raise ValueError("Received task outside of state machine")
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
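main() is a status-gated state machine: each task type is only legal from particular runner statuses (WaitingForModel -> Loading -> Loaded -> WarmingUp -> Ready <-> Running, with Shutdown available throughout). An illustrative transition table; the real code dispatches with match/case guards rather than a dict:

from enum import Enum, auto

class Status(Enum):
    WAITING_FOR_MODEL = auto()
    LOADING = auto()
    LOADED = auto()
    WARMING_UP = auto()
    READY = auto()
    RUNNING = auto()

# (current status, task type) -> status entered while handling the task
TRANSITIONS: dict[tuple[Status, str], Status] = {
    (Status.WAITING_FOR_MODEL, "LoadModel"): Status.LOADING,
    (Status.LOADED, "StartWarmup"): Status.WARMING_UP,
    (Status.READY, "ChatCompletion"): Status.RUNNING,
}

def step(status: Status, task: str) -> Status:
    try:
        return TRANSITIONS[(status, task)]
    except KeyError:
        raise ValueError("Received task outside of state machine") from None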

View File

@@ -19,8 +19,8 @@ from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerStatus,
RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -41,7 +41,7 @@ class RunnerSupervisor:
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
status: RunnerStatus = field(default_factory=RunnerWaitingForModel, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
@classmethod

View File

@@ -24,9 +24,3 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")
SHUTDOWN_TASK_ID = TaskId("shutdown")
CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
INITIALIZATION_TASK_ID = TaskId("initialisation")
LOAD_TASK_ID = TaskId("load")
WARMUP_TASK_ID = TaskId("warmup")

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from dataclasses import dataclass
from exo.shared.types.common import NodeId
@@ -16,7 +14,6 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
# Runner supervisor without multiprocessing logic.
@dataclass(frozen=True)
class FakeRunnerSupervisor:
bound_instance: BoundInstance
@@ -38,8 +35,6 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,
supports_tensor=False,
),
device_rank=device_rank,
world_size=world_size,
@@ -72,27 +67,5 @@ def get_mlx_ring_instance(
shard_assignments=get_shard_assignments(
model_id, node_to_runner, runner_to_shard
),
hosts_by_node={},
ephemeral_port=50000,
)
def get_bound_mlx_ring_instance(
instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
) -> BoundInstance:
shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)
other_shard = get_pipeline_shard_metadata(
model_id=model_id, device_rank=1, world_size=2
)
instance = get_mlx_ring_instance(
instance_id=instance_id,
model_id=model_id,
node_to_runner={
node_id: runner_id,
NodeId("other_node"): RunnerId("other_runner"),
},
runner_to_shard={runner_id: shard, RunnerId("other_runner"): other_shard},
)
return BoundInstance(
instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
hosts=[],
)

View File

@@ -4,8 +4,7 @@ from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerIdle,
RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
@@ -39,11 +38,13 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
@@ -81,15 +82,15 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
@@ -132,11 +133,13 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
# Local status claims the shard is downloaded already
local_download_status = {
@@ -180,14 +183,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
}
# Only NODE_A's shard is recorded as downloaded globally
@@ -210,22 +213,3 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
)
assert result is None
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
],  # NODE_B's shard is now downloaded too
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is not None

View File

@@ -5,9 +5,9 @@ from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerReady,
RunnerRunning,
RunnerWaitingForModel,
)
from exo.worker.tests.constants import (
COMMAND_1_ID,
@@ -99,7 +99,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerIdle(),
RUNNER_2_ID: RunnerWaitingForModel(),
}
task = ChatCompletion(

View File

@@ -2,9 +2,8 @@ import exo.worker.plan as plan_mod
from exo.shared.types.tasks import StartWarmup
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.worker.tests.constants import (
@@ -22,9 +21,9 @@ from exo.worker.tests.unittests.conftest import (
)
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
"""
For non-final device_rank shards, StartWarmup should be emitted when all
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
@@ -37,13 +36,13 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
@@ -51,10 +50,10 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
}
result = plan_mod.plan(
node_id=NODE_A,
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
@@ -129,7 +128,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerIdle(),
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerLoaded(),
}
@@ -150,9 +149,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -163,7 +159,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -178,93 +173,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
"""
For connecting rank (device_rank == world_size - 1), StartWarmup should
only be emitted once all the other runners are already warming up.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWarmingUp(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
"""
Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoading(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
@@ -276,46 +184,3 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
)
assert result is None
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == world_size - 1) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None

View File

@@ -1,208 +0,0 @@
# Check tasks are complete before runner is ever ready.
from collections.abc import Iterable
from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from ...constants import (
CHAT_COMPLETION_TASK_ID,
COMMAND_1_ID,
INITIALIZATION_TASK_ID,
INSTANCE_1_ID,
LOAD_TASK_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
SHUTDOWN_TASK_ID,
WARMUP_TASK_ID,
)
from ..conftest import get_bound_mlx_ring_instance
def make_nothin[T, U, V](res: T) -> Callable[[], T]:
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
nothin = make_nothin(None)
INIT_TASK = ConnectToGroup(
task_id=INITIALIZATION_TASK_ID,
instance_id=INSTANCE_1_ID,
)
LOAD_TASK = LoadModel(
task_id=LOAD_TASK_ID,
instance_id=INSTANCE_1_ID,
)
WARMUP_TASK = StartWarmup(
task_id=WARMUP_TASK_ID,
instance_id=INSTANCE_1_ID,
)
SHUTDOWN_TASK = Shutdown(
task_id=SHUTDOWN_TASK_ID,
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
CHAT_PARAMS = ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=4,
temperature=0.0,
)
CHAT_TASK = ChatCompletion(
task_id=CHAT_COMPLETION_TASK_ID,
command_id=COMMAND_1_ID,
task_params=CHAT_PARAMS,
instance_id=INSTANCE_1_ID,
)
def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
for test_event, true_event in zip(test_events, true_events, strict=True):
test_event.event_id = true_event.event_id
assert test_event == true_event, f"{test_event} != {true_event}"
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
with task_sender, event_receiver:
for t in tasks:
task_sender.send(t)
# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_receiver.collect()
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,
finish_reason="stop",
),
)
assert_events_equal(
events,
[
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
),
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)

View File

@@ -1 +0,0 @@
# TODO:

View File

@@ -1,6 +0,0 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]

View File

@@ -1,76 +0,0 @@
import http.client
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = response.read().decode("utf-8").strip()
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
return NodeId(body) or None
except OSError:
return None
finally:
connection.close()
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
)
return reachable
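This module moves into exo.utils.info_gatherer in this diff, but the probe mechanics stay the same: fetch /node_id over HTTP with a short timeout and keep the IP only if the reported identity matches. The core of that check, as a standalone sketch:

import http.client

def fetch_node_id(ip: str, port: int = 52415) -> str | None:
    # Returns the node's self-reported id, or None if unreachable.
    conn = http.client.HTTPConnection(ip, port, timeout=1)
    try:
        conn.request("GET", "/node_id")
        resp = conn.getresponse()
        if resp.status != 200:
            return None
        return resp.read().decode("utf-8").strip().strip('"') or None
    except OSError:
        return None
    finally:
        conn.close()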

View File

@@ -1,114 +0,0 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine
import anyio
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()
def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.
Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except asyncio.TimeoutError:
logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)

41
uv.lock generated
View File

@@ -334,10 +334,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -376,11 +374,9 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.30.1" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.30.1" },
{ name = "mlx", specifier = ">=0.30.1" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "protobuf", specifier = ">=6.32.0" },
{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pydantic", specifier = ">=2.11.7" },
@@ -805,20 +801,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]
[package.optional-dependencies]
cpu = [
{ name = "mlx-cpu", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx-cpu"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/51/32903727a68a61e972383e28a775c1f5e5f0628552c85cbc6103d68c0dc4/mlx_cpu-0.30.1-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:3f5dc2e4d0849181f8253508bb6a0854250483fc63d43ac79ec614b19824b172", size = 8992394, upload-time = "2025-12-18T00:16:13.696Z" },
{ url = "https://files.pythonhosted.org/packages/0c/74/69c21bb907f3c4064881ab0653029c939ae15fc4e63a5301ef8643cb1d68/mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:c9ea6992d8c001e1123dfd3b4d4405ff576c787eec52656ad405e3d033a8be60", size = 10553055, upload-time = "2025-12-18T00:16:16.104Z" },
]
[[package]]
name = "mlx-lm"
version = "0.28.3"
@@ -964,27 +946,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" },
]
[[package]]
name = "openai-harmony"
version = "0.0.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3e/92/2d038d096f29179c7c9571b431f9e739f87a487121901725e23fe338dd9d/openai_harmony-0.0.8.tar.gz", hash = "sha256:6e43f98e6c242fa2de6f8ea12eab24af63fa2ed3e89c06341fb9d92632c5cbdf", size = 284777, upload-time = "2025-11-05T19:07:06.727Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/c6/2502f416d46be3ec08bb66d696cccffb57781a499e3ff2e4d7c174af4e8f/openai_harmony-0.0.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:029ec25ca74abe48fdb58eb9fdd2a8c1618581fc33ce8e5653f8a1ffbfbd9326", size = 2627806, upload-time = "2025-11-05T19:06:57.063Z" },
{ url = "https://files.pythonhosted.org/packages/d3/d2/ce6953ca87db9cae3e775024184da7d1c5cb88cead19a2d75b42f00a959c/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4f709815924ec325b9a890e6ab2bbb0ceec8e319a4e257328eb752cf36b2efc", size = 2948463, upload-time = "2025-11-05T19:06:48.17Z" },
{ url = "https://files.pythonhosted.org/packages/fa/4c/b553c9651662d6ce102ca7f3629d268b23df1abe5841e24bed81e8a8e949/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cfcfd963b50a41fc656c84d3440ca6eecdccd6c552158ce790b8f2e33dfb5a9", size = 2704083, upload-time = "2025-11-05T19:06:50.205Z" },
{ url = "https://files.pythonhosted.org/packages/9b/af/4eec8f9ab9c27bcdb444460c72cf43011d176fc44c79d6e113094ca1e152/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a3a16972aa1cee38ea958470cd04ac9a2d5ac38fdcf77ab686611246220c158", size = 2959765, upload-time = "2025-11-05T19:06:53.62Z" },
{ url = "https://files.pythonhosted.org/packages/11/3c/33f3374e4624e0e776f6b13b73c45a7ead7f9c4529f8369ed5bfcaa30cac/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4d5cfa168e74d08f8ba6d58a7e49bc7daef4d58951ec69b66b0d56f4927a68d", size = 3427031, upload-time = "2025-11-05T19:06:51.829Z" },
{ url = "https://files.pythonhosted.org/packages/25/3f/1a192b93bb47c6b44cd98ba8cc1d3d2a9308f1bb700c3017e6352da11bda/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c007d277218a50db8839e599ed78e0fffe5130f614c3f6d93ae257f282071a29", size = 2953260, upload-time = "2025-11-05T19:06:55.406Z" },
{ url = "https://files.pythonhosted.org/packages/5b/f8/93b582cad3531797c3db7c2db5400fd841538ccddfd9f5e3df61be99a630/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8565d4f5a0638da1bffde29832ed63c9e695c558611053add3b2dc0b56c92dbc", size = 3127044, upload-time = "2025-11-05T19:06:59.553Z" },
{ url = "https://files.pythonhosted.org/packages/1d/10/4327dbf87f75ae813405fd9a9b4a5cde63d506ffed0a096a440a4cabd89c/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cbaa3bda75ef0d8836e1f8cc84af62f971b1d756d740efc95c38c3e04c0bfde2", size = 2932931, upload-time = "2025-11-05T19:07:01.437Z" },
{ url = "https://files.pythonhosted.org/packages/8a/c8/1774eec4f6f360ef57618fb8f52e3d3af245b2491bd0297513aa09eec04b/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:772922a9bd24e133950fad71eb1550836f415a88e8c77870e12d0c3bd688ddc2", size = 2996140, upload-time = "2025-11-05T19:07:03.438Z" },
{ url = "https://files.pythonhosted.org/packages/60/c3/3d1e01e2dba517a91760e4a03e4f20ffc75039a6fe584d0e6f9b5c78fd15/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:007b0476a1f331f8130783f901f1da6f5a7057af1a4891f1b6a31dec364189b5", size = 3205080, upload-time = "2025-11-05T19:07:05.078Z" },
]
[[package]]
name = "packaging"
version = "25.0"