mirror of https://github.com/exo-explore/exo.git
(synced 2025-12-30 09:40:46 -05:00)

Compare commits: 9 commits (`linux-cpu-` ... `optimize-d`)
| Author | SHA1 | Date |
|---|---|---|
|  | 40cbecb5c4 |  |
|  | fea42473dd |  |
|  | ca7adcc2a8 |  |
|  | 9d9e24f969 |  |
|  | b5d424b658 |  |
|  | b465134012 |  |
|  | eabdcab978 |  |
|  | 8e9332d6a7 |  |
|  | 4b65d5f896 |  |
`.gitignore` (vendored): 2 lines changed
@@ -7,6 +7,8 @@ digest.txt
 # nix
 .direnv/
 
+# IDEA (PyCharm)
+.idea
 
 # xcode / macos
 *.xcuserstate
`README.md`: 83 lines changed
@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua
 There are two ways to run exo:
 
-### Run from Source (Mac & Linux)
+### Run from Source (macOS)
 
 **Prerequisites:**
-- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
+- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
 
 ```bash
 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
@@ -98,6 +98,62 @@ uv run exo
 
 This starts the exo dashboard and API at http://localhost:52415/
 
+### Run from Source (Linux)
+
+**Prerequisites:**
+
+- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
+- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
+- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
+
+**Installation methods:**
+
+**Option 1: Using system package manager (Ubuntu/Debian example):**
+```bash
+# Install Node.js and npm
+sudo apt update
+sudo apt install nodejs npm
+
+# Install uv
+curl -LsSf https://astral.sh/uv/install.sh | sh
+
+# Install Rust (using rustup)
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+rustup toolchain install nightly
+```
+
+**Option 2: Using Homebrew on Linux (if preferred):**
+```bash
+# Install Homebrew on Linux
+/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
+
+# Install dependencies
+brew install uv node
+
+# Install Rust (using rustup)
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+rustup toolchain install nightly
+```
+
+**Note:** The `macmon` package is macOS-only and not required for Linux.
+
+Clone the repo, build the dashboard, and run exo:
+
+```bash
+# Clone exo
+git clone https://github.com/exo-explore/exo
+
+# Build dashboard
+cd exo/dashboard && npm install && npm run build && cd ..
+
+# Run exo
+uv run exo
+```
+
+This starts the exo dashboard and API at http://localhost:52415/
+
+**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
+
 ### macOS App
 
 exo ships a macOS app that runs in the background on your Mac.
@@ -112,6 +168,29 @@ The app will ask for permission to modify system settings and install a new Netw
 
 ---
 
+### Enabling RDMA on macOS
+
+RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
+
+Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
+
+To enable RDMA on macOS, follow these steps:
+
+1. Shut down your Mac.
+2. Hold down the power button for 10 seconds until the boot menu appears.
+3. Select "Options" to enter Recovery mode.
+4. When the Recovery UI appears, open the Terminal from the Utilities menu.
+5. In the Terminal, type:
+```
+rdma_ctl enable
+```
+and press Enter.
+6. Reboot your Mac.
+
+After that, RDMA will be enabled in macOS and exo will take care of the rest.
+
+---
+
 ### Using the API
 
 If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.
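A minimal sketch of that flow, assuming the dashboard port shown above and an OpenAI-compatible `/v1/chat/completions` route; the `/instance` create/delete paths used here are hypothetical placeholders, not confirmed exo endpoints:

```typescript
// Sketch of the create -> chat -> delete flow described in the README.
// Assumptions: the chat route follows the OpenAI convention
// (/v1/chat/completions); the /instance create/delete paths are
// hypothetical placeholders, not confirmed exo API routes.
const BASE = "http://localhost:52415";
const MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit";

async function main(): Promise<void> {
  // Create an instance of the model (hypothetical endpoint).
  const created = await fetch(`${BASE}/instance`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ model: MODEL }),
  });
  const instance = await created.json();

  // Send a chat completions request (OpenAI-style payload).
  const chat = await fetch(`${BASE}/v1/chat/completions`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      model: MODEL,
      messages: [{ role: "user", content: "Hello from exo!" }],
    }),
  });
  const completion = await chat.json();
  console.log(completion.choices?.[0]?.message?.content);

  // Delete the instance when done (hypothetical endpoint).
  await fetch(`${BASE}/instance/${instance.id}`, { method: "DELETE" });
}

main().catch(console.error);
```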
`TODO.md`: 1 line changed
@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer

Potential refactors:
@@ -49,7 +49,7 @@ struct ContentView: View {
 
     private var topologySection: some View {
         Group {
-            if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
+            if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
                 TopologyMiniView(topology: topology)
             }
         }
@@ -82,7 +82,6 @@ struct BugReportService {
    }

    private func loadCredentials() throws -> AWSConfig {
        // These credentials are write-only and necessary to receive bug reports from users
        return AWSConfig(
            accessKey: "AKIAYEKP5EMXTOBYDGHX",
            secretKey: "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",
@@ -7,6 +7,7 @@ final class ClusterStateService: ObservableObject {
     @Published private(set) var lastError: String?
     @Published private(set) var lastActionMessage: String?
     @Published private(set) var modelOptions: [ModelOption] = []
+    @Published private(set) var localNodeId: String?
 
     private var timer: Timer?
     private let decoder: JSONDecoder
@@ -29,6 +30,7 @@ final class ClusterStateService: ObservableObject {
     func startPolling(interval: TimeInterval = 0.5) {
         stopPolling()
         Task {
+            await fetchLocalNodeId()
             await fetchModels()
             await fetchSnapshot()
         }
@@ -46,9 +48,31 @@ final class ClusterStateService: ObservableObject {
         latestSnapshot = nil
         lastError = nil
         lastActionMessage = nil
+        localNodeId = nil
     }
 
+    private func fetchLocalNodeId() async {
+        do {
+            let url = baseURL.appendingPathComponent("node_id")
+            var request = URLRequest(url: url)
+            request.cachePolicy = .reloadIgnoringLocalCacheData
+            let (data, response) = try await session.data(for: request)
+            guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
+                return
+            }
+            if let nodeId = try? decoder.decode(String.self, from: data) {
+                localNodeId = nodeId
+            }
+        } catch {
+            // Silently ignore - localNodeId will remain nil and retry on next poll
+        }
+    }
+
     private func fetchSnapshot() async {
+        // Retry fetching local node ID if not yet set
+        if localNodeId == nil {
+            await fetchLocalNodeId()
+        }
         do {
             var request = URLRequest(url: endpoint)
             request.cachePolicy = .reloadIgnoringLocalCacheData
@@ -85,7 +85,7 @@ struct TopologyViewModel {
 }
 
 extension ClusterState {
-    func topologyViewModel() -> TopologyViewModel? {
+    func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
         let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
         let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
         guard !allNodes.isEmpty else { return nil }
@@ -105,6 +105,11 @@ extension ClusterState {
             orderedNodes = allNodes
         }
 
+        // Rotate so the local node (from /node_id API) is first
+        if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
+            orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
+        }
+
         let nodeIds = Set(orderedNodes.map(\.id))
         let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
             guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
@@ -112,10 +117,7 @@ extension ClusterState {
         } ?? []
         let edges = Set(edgesArray)
 
-        let topologyRootId = topology?.nodes.first?.nodeId
-        let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id
-
-        return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
+        return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
    }
}
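The new `topologyViewModel(localNodeId:)` rotates `orderedNodes` so the local node leads while preserving relative order. The same idea in a standalone TypeScript sketch (hypothetical helper, for illustration only):

```typescript
// Rotate an array so the element matching `localId` comes first, keeping
// the relative order of all elements (same idea as the Swift code above).
// Hypothetical helper, not part of the codebase.
function rotateToFront<T extends { id: string }>(nodes: T[], localId: string | null): T[] {
  if (localId === null) return nodes;
  const index = nodes.findIndex((n) => n.id === localId);
  if (index <= 0) return nodes; // not found, or already first
  return [...nodes.slice(index), ...nodes.slice(0, index)];
}

// Example: with nodes [a, b, c] and localId "b", the result is [b, c, a].
```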
`dashboard/package-lock.json` (generated): 48 lines changed
@@ -9,6 +9,8 @@
     "version": "1.0.0",
     "dependencies": {
       "highlight.js": "^11.11.1",
+      "katex": "^0.16.27",
+      "marked": "^17.0.1",
       "mode-watcher": "^1.1.0"
     },
     "devDependencies": {
@@ -861,7 +863,6 @@
      "integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
      "dev": true,
      "license": "MIT",
      "peer": true,
      "dependencies": {
        "@standard-schema/spec": "^1.0.0",
        "@sveltejs/acorn-typescript": "^1.0.5",
@@ -901,7 +902,6 @@
      "integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
      "dev": true,
      "license": "MIT",
      "peer": true,
      "dependencies": {
        "@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
        "debug": "^4.4.1",
@@ -1518,7 +1518,6 @@
      "integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
      "dev": true,
      "license": "MIT",
      "peer": true,
      "dependencies": {
        "undici-types": "~6.21.0"
      }
@@ -1528,7 +1527,6 @@
      "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
      "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
      "license": "MIT",
      "peer": true,
      "bin": {
        "acorn": "bin/acorn"
      },
@@ -1941,7 +1939,6 @@
      "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
      "dev": true,
      "license": "ISC",
      "peer": true,
      "engines": {
        "node": ">=12"
      }
@@ -2254,6 +2251,31 @@
        "jiti": "lib/jiti-cli.mjs"
      }
    },
+    "node_modules/katex": {
+      "version": "0.16.27",
+      "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
+      "integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
+      "funding": [
+        "https://opencollective.com/katex",
+        "https://github.com/sponsors/katex"
+      ],
+      "license": "MIT",
+      "dependencies": {
+        "commander": "^8.3.0"
+      },
+      "bin": {
+        "katex": "cli.js"
+      }
+    },
+    "node_modules/katex/node_modules/commander": {
+      "version": "8.3.0",
+      "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
+      "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
+      "license": "MIT",
+      "engines": {
+        "node": ">= 12"
+      }
+    },
     "node_modules/kleur": {
       "version": "4.1.5",
       "resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
@@ -2540,6 +2562,18 @@
        "@jridgewell/sourcemap-codec": "^1.5.5"
      }
    },
+    "node_modules/marked": {
+      "version": "17.0.1",
+      "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
+      "integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
+      "license": "MIT",
+      "bin": {
+        "marked": "bin/marked.js"
+      },
+      "engines": {
+        "node": ">= 20"
+      }
+    },
     "node_modules/mode-watcher": {
       "version": "1.1.0",
       "resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",
@@ -2612,7 +2646,6 @@
      "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
      "dev": true,
      "license": "MIT",
      "peer": true,
      "engines": {
        "node": ">=12"
      },
@@ -2800,7 +2833,6 @@
      "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
      "integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
      "license": "MIT",
      "peer": true,
      "dependencies": {
        "@jridgewell/remapping": "^2.3.4",
        "@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2945,7 +2977,6 @@
      "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
      "dev": true,
      "license": "Apache-2.0",
      "peer": true,
      "bin": {
        "tsc": "bin/tsc",
        "tsserver": "bin/tsserver"
@@ -2967,7 +2998,6 @@
      "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
      "dev": true,
      "license": "MIT",
      "peer": true,
      "dependencies": {
        "esbuild": "^0.25.0",
        "fdir": "^6.4.4",
@@ -27,7 +27,8 @@
   },
   "dependencies": {
     "highlight.js": "^11.11.1",
+    "katex": "^0.16.27",
     "marked": "^17.0.1",
     "mode-watcher": "^1.1.0"
   }
 }
@@ -198,8 +198,10 @@
   stroke: oklch(0.85 0.18 85 / 0.4);
   stroke-width: 1.5px;
   stroke-dasharray: 8, 8;
-  animation: flowAnimation 1s linear infinite;
+  animation: flowAnimation 1.5s linear infinite;
   filter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));
+  /* GPU optimization - hint to browser this element will animate */
+  will-change: stroke-dashoffset;
 }
 
 .graph-link-active {
@@ -208,6 +210,24 @@
   filter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));
 }
 
+/* Reduce motion for users who prefer it - also saves GPU */
+@media (prefers-reduced-motion: reduce) {
+  .graph-link {
+    animation: none;
+  }
+
+  .shooting-star {
+    animation: none;
+    display: none;
+  }
+
+  .status-pulse,
+  .cursor-blink,
+  .animate-pulse {
+    animation: none;
+  }
+}
+
 /* CRT Screen effect for topology */
 .crt-screen {
   position: relative;
@@ -266,13 +286,15 @@ input:focus, textarea:focus {
   box-shadow: none;
 }
 
-/* Shooting Stars Animation */
+/* Shooting Stars Animation - GPU optimized */
 .shooting-stars {
   position: fixed;
   inset: 0;
   overflow: hidden;
   pointer-events: none;
   z-index: 0;
+  /* Only render when visible */
+  content-visibility: auto;
 }
 
 .shooting-star {
@@ -285,6 +307,9 @@ input:focus, textarea:focus {
   animation: shootingStar var(--duration, 3s) linear infinite;
   animation-delay: var(--delay, 0s);
   opacity: 0;
+  /* GPU optimization */
+  will-change: transform, opacity;
+  transform: translateZ(0);
 }
 
 .shooting-star::before {
@@ -320,3 +345,13 @@ input:focus, textarea:focus {
     transform: translate(400px, 400px);
   }
 }
+
+/* Pause animations when page is hidden to save resources */
+:root:has(body[data-page-hidden="true"]) {
+  .shooting-star,
+  .graph-link,
+  .status-pulse,
+  .cursor-blink {
+    animation-play-state: paused;
+  }
+}
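The `data-page-hidden` attribute that the last rule keys on has to be set from script, which this diff does not show. A minimal sketch of such a hook using the standard Page Visibility API (the attribute name matches the CSS above; the rest of the wiring is an assumption):

```typescript
// Toggle data-page-hidden on <body> so the CSS rule above can pause
// animations while the tab is not visible. Sketch only; the actual
// wiring in the dashboard is not shown in this diff.
function trackPageVisibility(): () => void {
  const update = () => {
    document.body.dataset.pageHidden = document.hidden ? "true" : "false";
  };
  document.addEventListener("visibilitychange", update);
  update(); // initialize on load
  return () => document.removeEventListener("visibilitychange", update);
}
```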
@@ -8,89 +8,80 @@
  regenerateLastResponse
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import { tick, onDestroy } from 'svelte';
import MarkdownContent from './MarkdownContent.svelte';

interface Props {
  class?: string;
  scrollParent?: HTMLElement | null;
}

let { class: className = '', scrollParent = null }: Props = $props();

const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());

// Ref for scroll anchor at bottom
let scrollAnchorRef: HTMLDivElement | undefined = $state();
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
let showScrollButton = $state(false);
let lastMessageCount = 0;
let containerRef: HTMLDivElement | undefined = $state();

// Scroll management
const SCROLL_BOTTOM_THRESHOLD = 120;
let autoScrollEnabled = true;
let currentScrollEl: HTMLElement | null = null;

function resolveScrollElement(): HTMLElement | null {
  if (scrollParent) return scrollParent;
  let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
  while (node) {
    const isScrollable = node.scrollHeight > node.clientHeight + 1;
    if (isScrollable) return node;
    node = node.parentElement;
function getScrollContainer(): HTMLElement | null {
  if (scrollParent) return scrollParent;
  return containerRef?.parentElement ?? null;
}
  return null;
}

function handleScroll() {
  if (!currentScrollEl) return;
  const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
  const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
  autoScrollEnabled = isNearBottom;
}

function attachScrollListener() {
  const nextEl = resolveScrollElement();
  if (currentScrollEl === nextEl) return;
  if (currentScrollEl) {
    currentScrollEl.removeEventListener('scroll', handleScroll);
function isNearBottom(el: HTMLElement): boolean {
  return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
}
  currentScrollEl = nextEl;
  if (currentScrollEl) {
    currentScrollEl.addEventListener('scroll', handleScroll);
    // Initialize state based on current position
    handleScroll();
  }
}

onDestroy(() => {
  if (currentScrollEl) {
    currentScrollEl.removeEventListener('scroll', handleScroll);
  }
});

$effect(() => {
  // Re-evaluate scroll container if prop changes or after mount
  scrollParent;
  attachScrollListener();
});

// Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
$effect(() => {
  // Track these values to trigger effect
  const _ = messageList.length;
  const __ = response;
  const ___ = loading;

  tick().then(() => {
    const el = currentScrollEl ?? resolveScrollElement();
    if (!el || !scrollAnchorRef) return;
    const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
    const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
    if (autoScrollEnabled || isNearBottom) {
      scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
      autoScrollEnabled = true;
function scrollToBottom() {
  const el = getScrollContainer();
  if (el) {
    el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
  }
}

function updateScrollButtonVisibility() {
  const el = getScrollContainer();
  if (!el) return;
  showScrollButton = !isNearBottom(el);
}

// Attach scroll listener
$effect(() => {
  const el = scrollParent ?? containerRef?.parentElement;
  if (!el) return;

  el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
  // Initial check
  updateScrollButtonVisibility();
  return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
});

// Auto-scroll when user sends a new message
$effect(() => {
  const count = messageList.length;
  if (count > lastMessageCount) {
    const el = getScrollContainer();
    if (el) {
      requestAnimationFrame(() => {
        el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
      });
    }
  }
  lastMessageCount = count;
});

// Update scroll button visibility when content changes
$effect(() => {
  // Track response to trigger re-check during streaming
  const _ = response;

  // Small delay to let DOM update
  requestAnimationFrame(() => updateScrollButtonVisibility());
});
});

// Edit state
let editingMessageId = $state<string | null>(null);
@@ -231,7 +222,7 @@ function isThinkingExpanded(messageId: string): boolean {
 <div class="flex flex-col gap-4 sm:gap-6 {className}">
 {#each messageList as message (message.id)}
   <div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
-    <div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
+    <div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}">
       {#if message.role === 'assistant'}
         <!-- Assistant message header -->
         <div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
@@ -305,7 +296,7 @@ function isThinkingExpanded(messageId: string): boolean {
       {:else}
         <div class="{message.role === 'user'
           ? 'command-panel rounded-lg rounded-tr-sm inline-block'
-          : 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
+          : 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}">
 
         {#if message.role === 'user'}
           <!-- User message styling -->
@@ -331,7 +322,7 @@ function isThinkingExpanded(messageId: string): boolean {
         {/if}
 
         {#if message.content}
-          <div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
+          <div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
             {message.content}
           </div>
         {/if}
@@ -360,7 +351,7 @@ function isThinkingExpanded(messageId: string): boolean {
           </svg>
           <span>Thinking...</span>
         </span>
-        <span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
+        <span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
           {isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
         </span>
       </button>
@@ -374,8 +365,8 @@ function isThinkingExpanded(messageId: string): boolean {
         {/if}
       </div>
       {/if}
-      <div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
-        {message.content || (loading ? response : '')}
+      <div class="text-xs text-foreground">
+        <MarkdownContent content={message.content || (loading ? response : '')} />
         {#if loading && !message.content}
           <span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
         {/if}
@@ -457,6 +448,20 @@ function isThinkingExpanded(messageId: string): boolean {
   </div>
 {/if}
 
-<!-- Scroll anchor for auto-scroll -->
-<div bind:this={scrollAnchorRef}></div>
+<!-- Invisible element for container reference -->
+<div bind:this={containerRef}></div>
+
+<!-- Scroll to bottom button -->
+{#if showScrollButton}
+  <button
+    type="button"
+    onclick={scrollToBottom}
+    class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
+    title="Scroll to bottom"
+  >
+    <svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
+      <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
+    </svg>
+  </button>
+{/if}
 </div>
@@ -10,7 +10,9 @@ import {
   clearChat,
   instances,
   debugMode,
-  toggleDebugMode
+  toggleDebugMode,
+  topologyOnlyMode,
+  toggleTopologyOnlyMode
 } from '$lib/stores/app.svelte';
 
 interface Props {
@@ -23,6 +25,7 @@ import {
 const activeId = $derived(activeConversationId());
 const instanceData = $derived(instances());
 const debugEnabled = $derived(debugMode());
+const topologyOnlyEnabled = $derived(topologyOnlyMode());
 
 let searchQuery = $state('');
 let editingId = $state<string | null>(null);
@@ -424,6 +427,19 @@ const debugEnabled = $derived(debugMode());
   <div class="text-xs text-white/60 font-mono tracking-wider text-center">
     {conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
   </div>
+  <button
+    type="button"
+    onclick={toggleTopologyOnlyMode}
+    class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
+    title="Toggle topology only mode"
+  >
+    <svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
+      <circle cx="12" cy="5" r="2" fill="currentColor" />
+      <circle cx="5" cy="19" r="2" fill="currentColor" />
+      <circle cx="19" cy="19" r="2" fill="currentColor" />
+      <path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
+    </svg>
+  </button>
 </div>
 </div>
 </aside>
@@ -3,6 +3,9 @@
 
 export let showHome = true;
 export let onHome: (() => void) | null = null;
+export let showSidebarToggle = false;
+export let sidebarVisible = true;
+export let onToggleSidebar: (() => void) | null = null;
 
 function handleHome(): void {
   if (onHome) {
@@ -14,13 +17,38 @@
     window.location.hash = '/';
   }
 }
+
+function handleToggleSidebar(): void {
+  if (onToggleSidebar) {
+    onToggleSidebar();
+  }
+}
 </script>
 
 <header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
+  <!-- Left: Sidebar Toggle -->
+  {#if showSidebarToggle}
+    <div class="absolute left-6 top-1/2 -translate-y-1/2">
+      <button
+        onclick={handleToggleSidebar}
+        class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
+        title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
+      >
+        <svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
+          {#if sidebarVisible}
+            <path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
+          {:else}
+            <path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
+          {/if}
+        </svg>
+      </button>
+    </div>
+  {/if}
+
   <!-- Center: Logo (clickable to go home) -->
   <button
     onclick={handleHome}
-    class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
+    class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
     title={showHome ? 'Go to home' : ''}
     disabled={!showHome}
   >
`dashboard/src/lib/components/MarkdownContent.svelte` (new file): 451 lines
@@ -0,0 +1,451 @@
<script lang="ts">
import { marked } from 'marked';
import hljs from 'highlight.js';
import katex from 'katex';
import 'katex/dist/katex.min.css';
import { browser } from '$app/environment';

interface Props {
  content: string;
  class?: string;
}

let { content, class: className = '' }: Props = $props();

let containerRef = $state<HTMLDivElement>();
let processedHtml = $state('');

// Configure marked with syntax highlighting
marked.setOptions({
  gfm: true,
  breaks: true
});

// Custom renderer for code blocks
const renderer = new marked.Renderer();

renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
  const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
  const highlighted = hljs.highlight(text, { language }).value;
  const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;

  return `
    <div class="code-block-wrapper">
      <div class="code-block-header">
        <span class="code-language">${language}</span>
        <button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
          <svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
            <rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
            <path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
          </svg>
        </button>
      </div>
      <pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
    </div>
  `;
};

// Inline code
renderer.codespan = function ({ text }: { text: string }) {
  return `<code class="inline-code">${text}</code>`;
};

marked.use({ renderer });

/**
 * Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
 * Also protect code blocks from LaTeX processing
 */
function preprocessLaTeX(text: string): string {
  // Protect code blocks
  const codeBlocks: string[] = [];
  let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
    codeBlocks.push(match);
    return `<<CODE_${codeBlocks.length - 1}>>`;
  });

  // Convert \(...\) to $...$
  processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');

  // Convert \[...\] to $$...$$
  processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');

  // Restore code blocks
  processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);

  return processed;
}

/**
 * Render math expressions with KaTeX after HTML is generated
 */
function renderMath(html: string): string {
  // Render display math ($$...$$)
  html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
    try {
      return katex.renderToString(math.trim(), {
        displayMode: true,
        throwOnError: false,
        output: 'html'
      });
    } catch {
      return `<span class="math-error">$$${math}$$</span>`;
    }
  });

  // Render inline math ($...$) but avoid matching currency like $5
  html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
    // Skip if it looks like currency ($ followed by number)
    if (/^\d/.test(math.trim())) {
      return match;
    }
    try {
      return katex.renderToString(math.trim(), {
        displayMode: false,
        throwOnError: false,
        output: 'html'
      });
    } catch {
      return `<span class="math-error">$${math}$</span>`;
    }
  });

  return html;
}

function processMarkdown(text: string): string {
  try {
    // Preprocess LaTeX notation
    const preprocessed = preprocessLaTeX(text);
    // Parse markdown
    let html = marked.parse(preprocessed) as string;
    // Render math expressions
    html = renderMath(html);
    return html;
  } catch (error) {
    console.error('Markdown processing error:', error);
    return text.replace(/\n/g, '<br>');
  }
}

async function handleCopyClick(event: Event) {
  const target = event.currentTarget as HTMLButtonElement;
  const encodedCode = target.getAttribute('data-code');
  if (!encodedCode) return;

  const code = decodeURIComponent(encodedCode);

  try {
    await navigator.clipboard.writeText(code);
    // Show copied feedback
    const originalHtml = target.innerHTML;
    target.innerHTML = `
      <svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
        <path d="M20 6L9 17l-5-5"/>
      </svg>
    `;
    target.classList.add('copied');
    setTimeout(() => {
      target.innerHTML = originalHtml;
      target.classList.remove('copied');
    }, 2000);
  } catch (error) {
    console.error('Failed to copy:', error);
  }
}

function setupCopyButtons() {
  if (!containerRef || !browser) return;

  const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
  for (const button of buttons) {
    if (button.dataset.listenerBound !== 'true') {
      button.dataset.listenerBound = 'true';
      button.addEventListener('click', handleCopyClick);
    }
  }
}

$effect(() => {
  if (content) {
    processedHtml = processMarkdown(content);
  } else {
    processedHtml = '';
  }
});

$effect(() => {
  if (containerRef && processedHtml) {
    setupCopyButtons();
  }
});
</script>

<div bind:this={containerRef} class="markdown-content {className}">
  {@html processedHtml}
</div>

<style>
.markdown-content {
  line-height: 1.6;
}

/* Paragraphs */
.markdown-content :global(p) {
  margin-bottom: 1rem;
}

.markdown-content :global(p:last-child) {
  margin-bottom: 0;
}

/* Headers */
.markdown-content :global(h1) {
  font-size: 1.5rem;
  font-weight: 700;
  margin: 1.5rem 0 0.75rem 0;
  color: var(--exo-yellow, #ffd700);
}

.markdown-content :global(h2) {
  font-size: 1.25rem;
  font-weight: 600;
  margin: 1.25rem 0 0.5rem 0;
  color: var(--exo-yellow, #ffd700);
}

.markdown-content :global(h3) {
  font-size: 1.125rem;
  font-weight: 600;
  margin: 1rem 0 0.5rem 0;
}

.markdown-content :global(h4),
.markdown-content :global(h5),
.markdown-content :global(h6) {
  font-size: 1rem;
  font-weight: 600;
  margin: 0.75rem 0 0.25rem 0;
}

/* Bold and italic */
.markdown-content :global(strong) {
  font-weight: 600;
}

.markdown-content :global(em) {
  font-style: italic;
}

/* Inline code */
.markdown-content :global(.inline-code) {
  background: rgba(255, 215, 0, 0.1);
  color: var(--exo-yellow, #ffd700);
  padding: 0.125rem 0.375rem;
  border-radius: 0.25rem;
  font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
  font-size: 0.875em;
}

/* Links */
.markdown-content :global(a) {
  color: var(--exo-yellow, #ffd700);
  text-decoration: underline;
  text-underline-offset: 2px;
}

.markdown-content :global(a:hover) {
  opacity: 0.8;
}

/* Lists */
.markdown-content :global(ul) {
  list-style-type: disc;
  margin-left: 1.5rem;
  margin-bottom: 1rem;
}

.markdown-content :global(ol) {
  list-style-type: decimal;
  margin-left: 1.5rem;
  margin-bottom: 1rem;
}

.markdown-content :global(li) {
  margin-bottom: 0.25rem;
}

.markdown-content :global(li::marker) {
  color: var(--exo-light-gray, #9ca3af);
}

/* Blockquotes */
.markdown-content :global(blockquote) {
  border-left: 3px solid var(--exo-yellow, #ffd700);
  padding: 0.5rem 1rem;
  margin: 1rem 0;
  background: rgba(255, 215, 0, 0.05);
  border-radius: 0 0.25rem 0.25rem 0;
}

/* Tables */
.markdown-content :global(table) {
  width: 100%;
  margin: 1rem 0;
  border-collapse: collapse;
  font-size: 0.875rem;
}

.markdown-content :global(th) {
  background: rgba(255, 215, 0, 0.1);
  border: 1px solid rgba(255, 215, 0, 0.2);
  padding: 0.5rem;
  text-align: left;
  font-weight: 600;
}

.markdown-content :global(td) {
  border: 1px solid rgba(255, 255, 255, 0.1);
  padding: 0.5rem;
}

/* Horizontal rule */
.markdown-content :global(hr) {
  border: none;
  border-top: 1px solid rgba(255, 255, 255, 0.1);
  margin: 1.5rem 0;
}

/* Code block wrapper */
.markdown-content :global(.code-block-wrapper) {
  margin: 1rem 0;
  border-radius: 0.5rem;
  overflow: hidden;
  border: 1px solid rgba(255, 215, 0, 0.2);
  background: rgba(0, 0, 0, 0.4);
}

.markdown-content :global(.code-block-header) {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 0.5rem 0.75rem;
  background: rgba(255, 215, 0, 0.05);
  border-bottom: 1px solid rgba(255, 215, 0, 0.1);
}

.markdown-content :global(.code-language) {
  color: var(--exo-yellow, #ffd700);
  font-size: 0.7rem;
  font-weight: 500;
  text-transform: uppercase;
  letter-spacing: 0.1em;
  font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}

.markdown-content :global(.copy-code-btn) {
  display: flex;
  align-items: center;
  justify-content: center;
  padding: 0.25rem;
  background: transparent;
  border: none;
  color: var(--exo-light-gray, #9ca3af);
  cursor: pointer;
  transition: color 0.2s;
  border-radius: 0.25rem;
}

.markdown-content :global(.copy-code-btn:hover) {
  color: var(--exo-yellow, #ffd700);
}

.markdown-content :global(.copy-code-btn.copied) {
  color: #22c55e;
}

.markdown-content :global(.code-block-wrapper pre) {
  margin: 0;
  padding: 1rem;
  overflow-x: auto;
  background: transparent;
}

.markdown-content :global(.code-block-wrapper code) {
  font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
  font-size: 0.8125rem;
  line-height: 1.5;
  background: transparent;
}

/* Syntax highlighting - dark theme matching EXO style */
.markdown-content :global(.hljs) {
  color: #e5e7eb;
}

.markdown-content :global(.hljs-keyword),
.markdown-content :global(.hljs-selector-tag),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-section),
.markdown-content :global(.hljs-link) {
  color: #c084fc;
}

.markdown-content :global(.hljs-string),
.markdown-content :global(.hljs-title),
.markdown-content :global(.hljs-name),
.markdown-content :global(.hljs-type),
.markdown-content :global(.hljs-attribute),
.markdown-content :global(.hljs-symbol),
.markdown-content :global(.hljs-bullet),
.markdown-content :global(.hljs-addition),
.markdown-content :global(.hljs-variable),
.markdown-content :global(.hljs-template-tag),
.markdown-content :global(.hljs-template-variable) {
  color: #fbbf24;
}

.markdown-content :global(.hljs-comment),
.markdown-content :global(.hljs-quote),
.markdown-content :global(.hljs-deletion),
.markdown-content :global(.hljs-meta) {
  color: #6b7280;
}

.markdown-content :global(.hljs-number),
.markdown-content :global(.hljs-regexp),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-built_in) {
  color: #34d399;
}

.markdown-content :global(.hljs-function),
.markdown-content :global(.hljs-class .hljs-title) {
  color: #60a5fa;
}

/* KaTeX math styling */
.markdown-content :global(.katex) {
  font-size: 1.1em;
}

.markdown-content :global(.katex-display) {
  margin: 1rem 0;
  overflow-x: auto;
  overflow-y: hidden;
  padding: 0.5rem 0;
}

.markdown-content :global(.katex-display > .katex) {
  text-align: center;
}

.markdown-content :global(.math-error) {
  color: #f87171;
  font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
  font-size: 0.875em;
  background: rgba(248, 113, 113, 0.1);
  padding: 0.125rem 0.25rem;
  border-radius: 0.25rem;
}
</style>
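The replacement strings in `preprocessLaTeX` are easy to misread because `$` is an escape character in JavaScript replacement patterns: `$$` emits a literal `$` and `$1` is the first capture group. A short worked example of what the two substitutions produce:

```typescript
// Worked example of the replacement-string escaping used in
// preprocessLaTeX above: in String.replace, "$$" is a literal "$"
// and "$1" is the first capture group.
const inline = String.raw`Euler: \(e^{i\pi} + 1 = 0\)`;
console.log(inline.replace(/\\\((.+?)\\\)/g, '$$$1$'));
// -> Euler: $e^{i\pi} + 1 = 0$

const display = String.raw`\[\int_0^1 x^2\,dx = \tfrac{1}{3}\]`;
console.log(display.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$'));
// -> $$\int_0^1 x^2\,dx = \tfrac{1}{3}$$
```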
@@ -1,5 +1,6 @@
 <script lang="ts">
-import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
+import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
+import { debugMode, topologyData } from '$lib/stores/app.svelte';
 
 interface Props {
   model: { id: string; name?: string; storage_size_megabytes?: number };
@@ -206,12 +207,8 @@ function toggleNodeDetails(nodeId: string): void {
   const centerY = topoHeight / 2;
   const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;
 
-  // Use API preview data if available
+  // Only use API preview data - no local estimation
   const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
-  const canFit = hasApiPreview ? true : (() => {
-    const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
-    return totalAvailable >= estimatedMemory;
-  })();
   const error = apiPreview?.error ?? null;
 
   let placementNodes: Array<{
@@ -232,135 +229,140 @@ function toggleNodeDetails(nodeId: string): void {
  modelFillHeight: number;
}> = [];

if (hasApiPreview && apiPreview.memory_delta_by_node) {
  // Use API placement data
  const memoryDelta = apiPreview.memory_delta_by_node;
  placementNodes = nodeArray.map((n, i) => {
    const deltaBytes = memoryDelta[n.id] ?? 0;
    const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
    const isUsed = deltaBytes > 0;
    const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
    const safeTotal = Math.max(n.totalGB, 0.001);
    const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
    const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
    const screenHeight = iconSize * 0.58;

    return {
      id: n.id,
      deviceName: n.deviceName,
      deviceType: n.deviceType,
      totalGB: n.totalGB,
      currentUsedGB: n.usedGB,
      modelUsageGB,
      currentPercent,
      newPercent,
      isUsed,
      x: centerX + Math.cos(angle) * radius,
      y: centerY + Math.sin(angle) * radius,
      iconSize,
      screenHeight,
      currentFillHeight: screenHeight * (currentPercent / 100),
      modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
    };
  });
} else if (apiPreview?.error) {
  // API returned an error - model can't fit, show all nodes as unused
  placementNodes = nodeArray.map((n, i) => {
    const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
    const safeTotal = Math.max(n.totalGB, 0.001);
    const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
    const screenHeight = iconSize * 0.58;

    return {
      id: n.id,
      deviceName: n.deviceName,
      deviceType: n.deviceType,
      totalGB: n.totalGB,
      currentUsedGB: n.usedGB,
      modelUsageGB: 0,
      currentPercent,
      newPercent: currentPercent,
      isUsed: false,
      x: centerX + Math.cos(angle) * radius,
      y: centerY + Math.sin(angle) * radius,
      iconSize,
      screenHeight,
      currentFillHeight: screenHeight * (currentPercent / 100),
      modelFillHeight: 0
    };
  });
} else {
  // Fallback: local estimation based on sharding strategy
  const memoryNeeded = estimatedMemory;
// Use API placement data directly
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
placementNodes = nodeArray.map((n, i) => {
  const deltaBytes = memoryDelta[n.id] ?? 0;
  const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
  const isUsed = deltaBytes > 0;
  const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
  const safeTotal = Math.max(n.totalGB, 0.001);
  const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
  const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
  const screenHeight = iconSize * 0.58;

  if (sharding === 'Pipeline') {
    const memoryPerNode = memoryNeeded / numNodes;
    placementNodes = nodeArray.map((n, i) => {
      const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
      const safeTotal = Math.max(n.totalGB, 0.001);
      const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
      const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
      const screenHeight = iconSize * 0.58;

      return {
        id: n.id,
        deviceName: n.deviceName,
        deviceType: n.deviceType,
        totalGB: n.totalGB,
        currentUsedGB: n.usedGB,
        modelUsageGB: memoryPerNode,
        currentPercent,
        newPercent,
        isUsed: true,
        x: centerX + Math.cos(angle) * radius,
        y: centerY + Math.sin(angle) * radius,
        iconSize,
        screenHeight,
        currentFillHeight: screenHeight * (currentPercent / 100),
        modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
      };
    });
  } else {
    let remaining = memoryNeeded;
    placementNodes = nodeArray.map((n, i) => {
      const allocated = Math.min(remaining, n.availableGB);
      remaining -= allocated;
      const isUsed = allocated > 0;
      const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
      const safeTotal = Math.max(n.totalGB, 0.001);
      const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
      const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
      const screenHeight = iconSize * 0.58;

      return {
        id: n.id,
        deviceName: n.deviceName,
        deviceType: n.deviceType,
        totalGB: n.totalGB,
        currentUsedGB: n.usedGB,
        modelUsageGB: allocated,
        currentPercent,
        newPercent,
        isUsed,
        x: centerX + Math.cos(angle) * radius,
        y: centerY + Math.sin(angle) * radius,
        iconSize,
        screenHeight,
        currentFillHeight: screenHeight * (currentPercent / 100),
        modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
      };
    });
  }
}
  return {
    id: n.id,
    deviceName: n.deviceName,
    deviceType: n.deviceType,
    totalGB: n.totalGB,
    currentUsedGB: n.usedGB,
    modelUsageGB,
    currentPercent,
    newPercent,
    isUsed,
    x: centerX + Math.cos(angle) * radius,
    y: centerY + Math.sin(angle) * radius,
    iconSize,
    screenHeight,
    currentFillHeight: screenHeight * (currentPercent / 100),
    modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
  };
});

const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
-return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
+return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
});

const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
const placementError = $derived(apiPreview?.error ?? null);
const nodeCount = $derived(nodeList().length);
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));

// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');

// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
  if (!ip || !topology?.nodes) return null;

  // Strip port if present
  const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

  // Check specified node first
  const node = topology.nodes[nodeId];
  if (node) {
    const match = node.network_interfaces?.find((iface) =>
      (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
    );
    if (match?.name) return match.name;

    const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
    if (mapped) return mapped;
  }

  // Fallback: check all nodes
  for (const [, otherNode] of Object.entries(topology.nodes)) {
    if (!otherNode) continue;
    const match = otherNode.network_interfaces?.find((iface) =>
      (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
    );
    if (match?.name) return match.name;

    const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
    if (mapped) return mapped;
  }

  return null;
}

// Get directional arrow based on node positions
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
  const dx = toNode.x - fromNode.x;
  const dy = toNode.y - fromNode.y;
  const absX = Math.abs(dx);
  const absY = Math.abs(dy);

  if (absX > absY * 2) {
    return dx > 0 ? '→' : '←';
  } else if (absY > absX * 2) {
    return dy > 0 ? '↓' : '↑';
  } else {
    if (dx > 0 && dy > 0) return '↘';
    if (dx > 0 && dy < 0) return '↗';
    if (dx < 0 && dy > 0) return '↙';
    return '↖';
  }
}

// Get connection info for edges between two nodes
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
  if (!topology?.edges) return [];

  // Collect candidates for each direction
  const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
  const bToACandidates: Array<{ ip: string; iface: string | null }> = [];

  for (const edge of topology.edges) {
    const ip = edge.sendBackIp || '?';
    const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);

    if (edge.source === nodeId1 && edge.target === nodeId2) {
      aToBCandidates.push({ ip, iface });
    } else if (edge.source === nodeId2 && edge.target === nodeId1) {
      bToACandidates.push({ ip, iface });
    }
  }

  // Pick best (prefer non-loopback)
  const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
    if (candidates.length === 0) return null;
    return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
  };

  const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];

  const bestAtoB = pickBest(aToBCandidates);
  if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });

  const bestBtoA = pickBest(bToACandidates);
  if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });

  return result;
}
</script>

<div class="relative group">
@@ -453,6 +455,26 @@ function toggleNodeDetails(nodeId: string): void {
 
 <!-- Connection lines between nodes (if multiple) -->
 {#if preview.nodes.length > 1}
+  {@const usedNodes = preview.nodes.filter(n => n.isUsed)}
+  {@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
+  {@const allConnections = isDebugMode && usedNodes.length > 1 ? (() => {
+    const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
+    for (let i = 0; i < usedNodes.length; i++) {
+      for (let j = i + 1; j < usedNodes.length; j++) {
+        const n1 = usedNodes[i];
+        const n2 = usedNodes[j];
+        const midX = (n1.x + n2.x) / 2;
+        const midY = (n1.y + n2.y) / 2;
+        for (const c of getConnectionInfo(n1.id, n2.id)) {
+          const fromPos = nodePositions[c.from];
+          const toPos = nodePositions[c.to];
+          const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
+          conns.push({ ...c, midX, midY, arrow });
+        }
+      }
+    }
+    return conns;
+  })() : []}
   {#each preview.nodes as node, i}
     {#each preview.nodes.slice(i + 1) as node2}
       <line
@@ -464,6 +486,43 @@ function toggleNodeDetails(nodeId: string): void {
       />
     {/each}
   {/each}
+  <!-- Debug: Show connection IPs/interfaces in corners -->
+  {#if isDebugMode && allConnections.length > 0}
+    {@const centerX = preview.topoWidth / 2}
+    {@const centerY = preview.topoHeight / 2}
+    {@const quadrants = {
+      topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
+      topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
+      bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
+      bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
+    }}
+    {@const padding = 4}
+    {@const lineHeight = 8}
+    <!-- Top Left -->
+    {#each quadrants.topLeft as conn, idx}
+      <text x={padding} y={padding + idx * lineHeight} text-anchor="start" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
+        {conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
+      </text>
+    {/each}
+    <!-- Top Right -->
+    {#each quadrants.topRight as conn, idx}
+      <text x={preview.topoWidth - padding} y={padding + idx * lineHeight} text-anchor="end" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
+        {conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
+      </text>
+    {/each}
+    <!-- Bottom Left -->
+    {#each quadrants.bottomLeft as conn, idx}
+      <text x={padding} y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight} text-anchor="start" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
+        {conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
+      </text>
+    {/each}
+    <!-- Bottom Right -->
+    {#each quadrants.bottomRight as conn, idx}
+      <text x={preview.topoWidth - padding} y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight} text-anchor="end" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
+        {conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
+      </text>
+    {/each}
+  {/if}
 {/if}
 
 {#each preview.nodes as node}
@@ -1,5 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { onMount, onDestroy } from 'svelte';
|
||||
import { onMount, onDestroy, tick } from 'svelte';
|
||||
import * as d3 from 'd3';
|
||||
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
|
||||
|
||||
@@ -12,11 +12,35 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv
|
||||
|
||||
let svgContainer: SVGSVGElement | undefined = $state();
|
||||
let resizeObserver: ResizeObserver | undefined;
|
||||
|
||||
// Optimization: Track last render state to avoid unnecessary re-renders
|
||||
let lastRenderHash = '';
|
||||
let lastHighlightedNodesHash = '';
|
||||
let lastDimensions = { width: 0, height: 0 };
|
||||
let isRendering = false;
|
||||
let pendingRender = false;
|
||||
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
|
||||
// Generate a hash of relevant data to detect actual changes
|
||||
function generateDataHash(topologyData: typeof data, minimized: boolean, debug: boolean): string {
|
||||
if (!topologyData) return 'null';
|
||||
const nodes = topologyData.nodes || {};
|
||||
const edges = topologyData.edges || [];
|
||||
|
||||
// Create a lightweight hash from key properties only
|
||||
const nodeHashes = Object.entries(nodes).map(([id, n]) => {
|
||||
const macmon = n.macmon_info;
|
||||
return `${id}:${n.friendly_name || ''}:${macmon?.memory?.ram_usage || 0}:${macmon?.memory?.ram_total || 0}:${macmon?.temp?.gpu_temp_avg || 0}:${macmon?.gpu_usage?.[1] || 0}:${macmon?.sys_power || 0}`;
|
||||
}).sort().join('|');
|
||||
|
||||
const edgeHash = edges.map(e => `${e.source}-${e.target}`).sort().join(',');
|
||||
|
||||
return `${nodeHashes}::${edgeHash}::${minimized}::${debug}`;
|
||||
}
|
||||
|
||||
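The idea behind `generateDataHash` is cheap change detection: serialize only the fields that affect the drawing, compare against the previous value, and skip the expensive D3 render when they match. The same pattern in isolation (simplified shapes, not the component's actual types):

```ts
type Snapshot = { nodes: Record<string, { name: string; ram: number }>; minimized: boolean };

let lastHash = '';

function hashOf(s: Snapshot): string {
  const nodePart = Object.entries(s.nodes)
    .map(([id, n]) => `${id}:${n.name}:${n.ram}`)
    .sort() // order-independent: object key order carries no meaning here
    .join('|');
  return `${nodePart}::${s.minimized}`;
}

function maybeRender(s: Snapshot, render: () => void): void {
  const h = hashOf(s);
  if (h === lastHash) return; // nothing visible changed; skip the D3 work
  lastHash = h;
  render();
}
```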
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
@@ -24,19 +48,36 @@ function getNodeLabel(nodeId: string): string {

function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
const node = data?.nodes?.[nodeId];
if (!node) return { label: '?', missing: true };

// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

// Helper to check a node's interfaces
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;

const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (matchFromInterfaces?.name) {
return matchFromInterfaces.name;
}

const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === ip)
);
if (matchFromInterfaces?.name) {
return { label: matchFromInterfaces.name, missing: false };
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}

const mapped = node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return { label: mapped, missing: false };

// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };

// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
if (otherResult) return { label: otherResult, missing: false };
}

return { label: '?', missing: true };
@@ -67,6 +108,7 @@ function wrapLine(text: string, maxLen: number): string[] {
return lines;
}

// Apple logo path for MacBook Pro screen
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
const LOGO_NATIVE_WIDTH = 814;
@@ -238,6 +280,7 @@ function wrapLine(text: string, maxLen: number): string[] {
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');

const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
@@ -314,110 +357,98 @@ function wrapLine(text: string, maxLen: number): string[] {
.attr('marker-end', 'url(#arrowhead)');
}

// Collect debug labels for later positioning at edges
if (debugEnabled && entry.connections.length > 0) {
const maxBoxes = 6;
const fontSize = isMinimized ? 8 : 9;
const lineGap = 2;
const labelOffsetOut = Math.max(140, minDimension * 0.38);
const labelOffsetSide = isMinimized ? 16 : 20;
const boxWidth = 170;
const maxLineLen = 26;

const connections = entry.connections.slice(0, maxBoxes);
if (entry.connections.length > maxBoxes) {
const remaining = entry.connections.length - maxBoxes;
connections.push({
from: '',
to: '',
ip: `(+${remaining} more)`,
ifaceLabel: '',
missingIface: false
});
}

let dirX = mx - centerX;
let dirY = my - centerY;
const dirLen = Math.hypot(dirX, dirY);
if (dirLen < 1) {
dirX = -uy;
dirY = ux;
} else {
dirX /= dirLen;
dirY /= dirLen;
}

const nx = -dirY;
const ny = dirX;

const labelXRaw = mx + dirX * labelOffsetOut + nx * labelOffsetSide;
const labelYRaw = my + dirY * labelOffsetOut + ny * labelOffsetSide;
const clampPad = Math.min(120, minDimension * 0.12);
const labelX = Math.max(clampPad, Math.min(width - clampPad, labelXRaw));
const labelY = Math.max(clampPad, Math.min(height - clampPad, labelYRaw));

const labelGroup = debugLabelsGroup.append('g')
.attr('transform', `translate(${labelX}, ${labelY})`);

const textGroup = labelGroup.append('g');

connections.forEach((conn, idx) => {
const rawLines = conn.from && conn.to
? [
`${getNodeLabel(conn.from)}→${getNodeLabel(conn.to)}`,
`${conn.ip}`,
`${conn.ifaceLabel}`
]
: [conn.ip];

const wrapped = rawLines.flatMap(line => wrapLine(line, maxLineLen));

wrapped.forEach((line, lineIdx) => {
textGroup.append('text')
.attr('x', 0)
.attr('y', (idx * (wrapped.length * (fontSize + lineGap))) + lineIdx * (fontSize + lineGap))
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'hanging')
.attr('font-size', fontSize)
.attr('font-family', 'SF Mono, monospace')
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.9)')
.text(line);
});
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;

// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
isTop,
mx,
my
});

const bbox = textGroup.node()?.getBBox();
if (bbox) {
const paddedWidth = Math.max(boxWidth, bbox.width + 14);
const boxHeight = bbox.height + 8;
const boxMinX = labelX - paddedWidth / 2;
const boxMaxX = labelX + paddedWidth / 2;
const boxMinY = labelY + bbox.y - 4;
const boxMaxY = boxMinY + boxHeight;

const clampPadDynamic = Math.min(140, minDimension * 0.18);
let shiftX = 0;
let shiftY = 0;
if (boxMinX < clampPadDynamic) shiftX = clampPadDynamic - boxMinX;
if (boxMaxX > width - clampPadDynamic) shiftX = (width - clampPadDynamic) - boxMaxX;
if (boxMinY < clampPadDynamic) shiftY = clampPadDynamic - boxMinY;
if (boxMaxY > height - clampPadDynamic) shiftY = (height - clampPadDynamic) - boxMaxY;

const finalX = labelX + shiftX;
const finalY = labelY + shiftY;
labelGroup.attr('transform', `translate(${finalX}, ${finalY})`);

labelGroup.insert('rect', 'g')
.attr('x', -paddedWidth / 2)
.attr('y', bbox.y - 4)
.attr('width', paddedWidth)
.attr('height', boxHeight)
.attr('rx', 4)
.attr('fill', 'rgba(0,0,0,0.75)')
.attr('stroke', 'rgba(255,255,255,0.12)')
.attr('stroke-width', 0.6);
}
}
});

// Render debug labels at viewport edges/corners
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
const fontSize = isMinimized ? 10 : 12;
const lineHeight = fontSize + 4;
const padding = 10;

// Helper to get arrow based on direction vector
function getArrow(fromId: string, toId: string): string {
const fromPos = positionById[fromId];
const toPos = positionById[toId];
if (!fromPos || !toPos) return '→';

const dirX = toPos.x - fromPos.x;
const dirY = toPos.y - fromPos.y;
const absX = Math.abs(dirX);
const absY = Math.abs(dirY);

if (absX > absY * 2) {
return dirX > 0 ? '→' : '←';
} else if (absY > absX * 2) {
return dirY > 0 ? '↓' : '↑';
} else {
if (dirX > 0 && dirY > 0) return '↘';
if (dirX > 0 && dirY < 0) return '↗';
if (dirX < 0 && dirY > 0) return '↙';
return '↖';
}
}
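`getArrow` quantizes a direction vector into one of eight glyphs: the direction counts as "mostly horizontal" or "mostly vertical" only when one component dominates the other by 2x; otherwise it falls back to a diagonal. A tiny check of that behavior, assuming the same thresholds:

```ts
function arrowFor(dx: number, dy: number): string {
  const ax = Math.abs(dx), ay = Math.abs(dy);
  if (ax > ay * 2) return dx > 0 ? '→' : '←'; // dominantly horizontal
  if (ay > ax * 2) return dy > 0 ? '↓' : '↑'; // dominantly vertical (SVG y grows downward)
  if (dx > 0) return dy > 0 ? '↘' : '↗';      // diagonals
  return dy > 0 ? '↙' : '↖';
}

console.log(arrowFor(10, 1));  // '→'
console.log(arrowFor(3, -4));  // '↗'  (neither axis dominates by 2x)
```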
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};

debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});

// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;

const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');

let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';

let currentY = baseY;

edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
debugLabelsGroup.append('text')
.attr('x', baseX)
.attr('y', currentY)
.attr('text-anchor', textAnchor)
.attr('dominant-baseline', isTop ? 'hanging' : 'auto')
.attr('font-size', fontSize)
.attr('font-family', 'SF Mono, monospace')
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
.text(label);
currentY += isTop ? lineHeight : -lineHeight;
});
});
});
}

// Draw nodes
const nodesGroup = svg.append('g').attr('class', 'nodes-group');

@@ -925,16 +956,59 @@ function wrapLine(text: string, maxLen: number): string[] {

}

$effect(() => {
if (data) {
// Throttled render function to prevent too-frequent updates
function scheduleRender() {
if (isRendering) {
pendingRender = true;
return;
}

isRendering = true;
requestAnimationFrame(() => {
renderGraph();
isRendering = false;

if (pendingRender) {
pendingRender = false;
scheduleRender();
}
});
}
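`scheduleRender` is a classic requestAnimationFrame coalescer: at most one render per frame, and any request that arrives mid-render is folded into a single follow-up pass rather than queued. The pattern in isolation (a browser-only sketch; `render` stands in for `renderGraph`):

```ts
let rendering = false;
let pending = false;

function schedule(render: () => void): void {
  if (rendering) {
    pending = true; // remember that something changed while we were busy
    return;
  }
  rendering = true;
  requestAnimationFrame(() => {
    render();
    rendering = false;
    if (pending) {
      pending = false;
      schedule(render); // one catch-up pass covers any number of missed requests
    }
  });
}
```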
$effect(() => {
if (!data || !svgContainer) return;

// Generate hash of current state
const currentHash = generateDataHash(data, isMinimized, debugEnabled);
const highlightHash = Array.from(highlightedNodes).sort().join(',');

// Get current dimensions
const rect = svgContainer.getBoundingClientRect();
const dimensionsChanged = rect.width !== lastDimensions.width || rect.height !== lastDimensions.height;

// Only re-render if something actually changed
if (currentHash !== lastRenderHash || highlightHash !== lastHighlightedNodesHash || dimensionsChanged) {
lastRenderHash = currentHash;
lastHighlightedNodesHash = highlightHash;
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
});

onMount(() => {
if (svgContainer) {
// Use a debounced resize observer to prevent rapid re-renders
let resizeTimeout: ReturnType<typeof setTimeout> | null = null;

resizeObserver = new ResizeObserver(() => {
renderGraph();
if (resizeTimeout) clearTimeout(resizeTimeout);
resizeTimeout = setTimeout(() => {
const rect = svgContainer!.getBoundingClientRect();
if (rect.width !== lastDimensions.width || rect.height !== lastDimensions.height) {
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
}, 100);
});
resizeObserver.observe(svgContainer);
}
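Note the two layers here: a `ResizeObserver` can fire many times during a drag-resize, so the callback only restarts a 100 ms timeout, and work happens once the dimensions have actually settled on a new value. The debounce reduced to its core (a sketch with the same assumed timing):

```ts
let timer: ReturnType<typeof setTimeout> | null = null;
let last = { width: 0, height: 0 };

const observer = new ResizeObserver((entries) => {
  if (timer) clearTimeout(timer); // restart the quiet-period window
  timer = setTimeout(() => {
    const { width, height } = entries[entries.length - 1].contentRect;
    if (width !== last.width || height !== last.height) {
      last = { width, height };
      // scheduleRender() would go here
    }
  }, 100);
});
```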
@@ -962,10 +1036,20 @@ function wrapLine(text: string, maxLen: number): string[] {
stroke-width: 1px;
stroke-dasharray: 4, 4;
opacity: 0.8;
animation: flowAnimation 0.75s linear infinite;
/* Slower animation = less GPU usage */
animation: flowAnimation 2s linear infinite;
/* GPU optimization */
will-change: stroke-dashoffset;
}
@keyframes flowAnimation {
from { stroke-dashoffset: 0; }
to { stroke-dashoffset: -10; }
}

/* Respect reduced motion preference */
@media (prefers-reduced-motion: reduce) {
:global(.graph-link) {
animation: none;
}
}
</style>

@@ -4,4 +4,5 @@ export { default as ChatMessages } from './ChatMessages.svelte';
export { default as ChatAttachments } from './ChatAttachments.svelte';
export { default as ChatSidebar } from './ChatSidebar.svelte';
export { default as ModelCard } from './ModelCard.svelte';
export { default as MarkdownContent } from './MarkdownContent.svelte';

@@ -96,7 +96,7 @@ interface RawNodeProfile {

interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
nodeProfile: RawNodeProfile;
}

interface RawTopologyConnection {
@@ -105,13 +105,9 @@ interface RawTopologyConnection {
sendBackMultiaddr?: { multiaddr?: string; address?: string; ip_address?: string } | string;
}

// Connection can be an object or a tuple [source, target, metadata]
type RawConnectionItem = RawTopologyConnection | [string, string, { sinkMultiaddr?: { ip_address?: string; address?: string } }?];

interface RawTopology {
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
connections?: RawConnectionItem[];
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
}

type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -202,17 +198,9 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];

// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === 'string' ? node : node.nodeId;
if (!nodeId) continue;

// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode = typeof node === 'object' ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };

const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -250,7 +238,7 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
}

nodes[nodeId] = {
nodes[node.nodeId] = {
system_info: {
model_id: profile?.modelId ?? 'Unknown',
chip: profile?.chipId,
@@ -272,34 +260,14 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
};
}

// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
for (const conn of raw.connections || []) {
let localNodeId: string | undefined;
let sendBackNodeId: string | undefined;
let sendBackMultiaddr: { multiaddr?: string; address?: string; ip_address?: string } | string | undefined;

// Check if it's a tuple format [source, target, metadata]
if (Array.isArray(conn)) {
localNodeId = conn[0] as string;
sendBackNodeId = conn[1] as string;
const metadata = conn[2] as { sinkMultiaddr?: { ip_address?: string; address?: string } } | undefined;
if (metadata?.sinkMultiaddr) {
sendBackMultiaddr = metadata.sinkMultiaddr;
}
} else {
// Object format with localNodeId/sendBackNodeId
localNodeId = conn.localNodeId;
sendBackNodeId = conn.sendBackNodeId;
sendBackMultiaddr = conn.sendBackMultiaddr;
}

if (!localNodeId || !sendBackNodeId) continue;
if (localNodeId === sendBackNodeId) continue;
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;

let sendBackIp: string | undefined;
if (sendBackMultiaddr) {
const multi = sendBackMultiaddr;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (typeof multi === 'string') {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
@@ -308,8 +276,8 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}

edges.push({
source: localNodeId,
target: sendBackNodeId,
source: conn.localNodeId,
target: conn.sendBackNodeId,
sendBackIp
});
}
@@ -329,6 +297,35 @@ function extractIpFromMultiaddr(ma?: string): string | undefined {
return undefined;
}

// Shallow comparison utility for preventing unnecessary state updates
function shallowEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
if (a === null || b === null) return false;
if (typeof a !== 'object' || typeof b !== 'object') return false;

const aObj = a as Record<string, unknown>;
const bObj = b as Record<string, unknown>;
const aKeys = Object.keys(aObj);
const bKeys = Object.keys(bObj);

if (aKeys.length !== bKeys.length) return false;

for (const key of aKeys) {
if (aObj[key] !== bObj[key]) return false;
}
return true;
}

// Faster JSON comparison for complex nested objects
function jsonEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
try {
return JSON.stringify(a) === JSON.stringify(b);
} catch {
return false;
}
}
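Of the two comparison helpers, `shallowEqual` is O(number of keys) but only sees one level deep, while `jsonEqual` handles arbitrary nesting at the cost of serializing both values, and it misses equality when key order differs. A quick illustration of where they disagree:

```ts
shallowEqual({ a: 1 }, { a: 1 });               // true  — same top-level values
shallowEqual({ a: { b: 1 } }, { a: { b: 1 } }); // false — nested objects are distinct references
jsonEqual({ a: { b: 1 } }, { a: { b: 1 } });    // true  — structural comparison via serialization
jsonEqual({ a: 1, b: 2 }, { b: 2, a: 1 });      // false — JSON.stringify is key-order sensitive
```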
class AppStore {
// Conversation state
conversations = $state<Conversation[]>([]);
@@ -359,19 +356,49 @@ class AppStore {
isTopologyMinimized = $state(false);
isSidebarOpen = $state(false); // Hidden by default, shown when in chat mode
debugMode = $state(false);
topologyOnlyMode = $state(false);
chatSidebarVisible = $state(true); // Shown by default

// Visibility state - used to pause polling when tab is hidden
private isPageVisible = true;

private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;

// Cache for comparison - prevents unnecessary reactivity
private lastTopologyJson = '';
private lastInstancesJson = '';
private lastRunnersJson = '';
private lastDownloadsJson = '';

constructor() {
if (browser) {
this.startPolling();
this.loadConversationsFromStorage();
this.loadDebugModeFromStorage();
this.loadTopologyOnlyModeFromStorage();
this.loadChatSidebarVisibleFromStorage();
this.setupVisibilityListener();
}
}

/**
* Listen for page visibility changes to pause polling when hidden
*/
private setupVisibilityListener() {
if (typeof document === 'undefined') return;

document.addEventListener('visibilitychange', () => {
this.isPageVisible = document.visibilityState === 'visible';

if (this.isPageVisible) {
// Resume polling when page becomes visible
this.fetchState();
}
});
}
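This is the standard Page Visibility pattern: a hidden tab stops paying for polling, and the first thing that happens on return is an immediate refresh, so the UI never sits stale for a full poll interval. Standalone it looks like this (a sketch; `poll` stands in for `fetchState`):

```ts
let visible = true;

function watchVisibility(poll: () => void): void {
  document.addEventListener('visibilitychange', () => {
    visible = document.visibilityState === 'visible';
    if (visible) poll(); // catch up immediately instead of waiting for the next tick
  });
}

// Inside the polling loop, a hidden page becomes a no-op:
// if (!visible) return;
```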
/**
* Load conversations from localStorage
*/
@@ -426,6 +453,44 @@ class AppStore {
}
}

private loadTopologyOnlyModeFromStorage() {
try {
const stored = localStorage.getItem('exo-topology-only-mode');
if (stored !== null) {
this.topologyOnlyMode = stored === 'true';
}
} catch (error) {
console.error('Failed to load topology only mode:', error);
}
}

private saveTopologyOnlyModeToStorage() {
try {
localStorage.setItem('exo-topology-only-mode', this.topologyOnlyMode ? 'true' : 'false');
} catch (error) {
console.error('Failed to save topology only mode:', error);
}
}

private loadChatSidebarVisibleFromStorage() {
try {
const stored = localStorage.getItem('exo-chat-sidebar-visible');
if (stored !== null) {
this.chatSidebarVisible = stored === 'true';
}
} catch (error) {
console.error('Failed to load chat sidebar visibility:', error);
}
}

private saveChatSidebarVisibleToStorage() {
try {
localStorage.setItem('exo-chat-sidebar-visible', this.chatSidebarVisible ? 'true' : 'false');
} catch (error) {
console.error('Failed to save chat sidebar visibility:', error);
}
}
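These four methods all follow one shape: a string-encoded boolean under a fixed `localStorage` key, wrapped in try/catch because storage access can throw (private browsing, quota). They could be collapsed into a pair of helpers like the following (a refactoring sketch, not code from the repo):

```ts
function loadBool(key: string, fallback: boolean): boolean {
  try {
    const stored = localStorage.getItem(key);
    return stored === null ? fallback : stored === 'true';
  } catch {
    return fallback; // storage unavailable: keep the default
  }
}

function saveBool(key: string, value: boolean): void {
  try {
    localStorage.setItem(key, value ? 'true' : 'false');
  } catch (error) {
    console.error(`Failed to persist ${key}:`, error);
  }
}

// this.topologyOnlyMode = loadBool('exo-topology-only-mode', false);
```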
/**
* Create a new conversation
*/
@@ -730,9 +795,39 @@ class AppStore {
this.saveDebugModeToStorage();
}

getTopologyOnlyMode(): boolean {
return this.topologyOnlyMode;
}

setTopologyOnlyMode(enabled: boolean) {
this.topologyOnlyMode = enabled;
this.saveTopologyOnlyModeToStorage();
}

toggleTopologyOnlyMode() {
this.topologyOnlyMode = !this.topologyOnlyMode;
this.saveTopologyOnlyModeToStorage();
}

getChatSidebarVisible(): boolean {
return this.chatSidebarVisible;
}

setChatSidebarVisible(visible: boolean) {
this.chatSidebarVisible = visible;
this.saveChatSidebarVisibleToStorage();
}

toggleChatSidebarVisible() {
this.chatSidebarVisible = !this.chatSidebarVisible;
this.saveChatSidebarVisibleToStorage();
}

startPolling() {
this.fetchState();
this.fetchInterval = setInterval(() => this.fetchState(), 1000);
// Poll every 2 seconds instead of 1 second - reduces CPU/GPU load by 50%
// Data comparison ensures we only update when something actually changes
this.fetchInterval = setInterval(() => this.fetchState(), 2000);
}

stopPolling() {
@@ -744,6 +839,9 @@ class AppStore {
}

async fetchState() {
// Skip polling when page is hidden to save resources
if (!this.isPageVisible) return;

try {
const response = await fetch('/state');
if (!response.ok) {
@@ -751,19 +849,44 @@ class AppStore {
}
const data: RawStateResponse = await response.json();

// Only update topology if it actually changed (prevents unnecessary D3 re-renders)
if (data.topology) {
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
const newTopology = transformTopology(data.topology, data.nodeProfiles);
const newTopologyJson = JSON.stringify(newTopology);
if (newTopologyJson !== this.lastTopologyJson) {
this.lastTopologyJson = newTopologyJson;
this.topologyData = newTopology;
}
}

// Only update instances if changed
if (data.instances) {
this.instances = data.instances;
this.refreshConversationModelFromInstances();
const newInstancesJson = JSON.stringify(data.instances);
if (newInstancesJson !== this.lastInstancesJson) {
this.lastInstancesJson = newInstancesJson;
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
}

// Only update runners if changed
if (data.runners) {
this.runners = data.runners;
const newRunnersJson = JSON.stringify(data.runners);
if (newRunnersJson !== this.lastRunnersJson) {
this.lastRunnersJson = newRunnersJson;
this.runners = data.runners;
}
}

// Only update downloads if changed
if (data.downloads) {
this.downloads = data.downloads;
const newDownloadsJson = JSON.stringify(data.downloads);
if (newDownloadsJson !== this.lastDownloadsJson) {
this.lastDownloadsJson = newDownloadsJson;
this.downloads = data.downloads;
}
}

this.lastUpdate = Date.now();
} catch (error) {
console.error('Error fetching state:', error);
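Every branch of `fetchState` now applies the same guard: serialize the incoming slice, compare to the previous serialization, and assign to the `$state` field only on a real change, since assigning a fresh object invalidates every subscriber even when nothing differs. The guard as a reusable function (an illustrative sketch, not the store's actual helper):

```ts
function updateIfChanged<T>(next: T, lastJson: string, apply: (value: T) => void): string {
  const json = JSON.stringify(next);
  if (json !== lastJson) apply(next); // touch reactive state only when the payload changed
  return json;                        // caller stores this as the new comparison baseline
}

// this.lastRunnersJson = updateIfChanged(data.runners, this.lastRunnersJson,
//   (runners) => { this.runners = runners; });
```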
@@ -920,8 +1043,6 @@ class AppStore {

if (lastUserIndex === -1) return;

const lastUserMessage = this.messages[lastUserIndex];

// Remove any messages after the user message
this.messages = this.messages.slice(0, lastUserIndex + 1);

@@ -962,7 +1083,10 @@ class AppStore {
}

if (!modelToUse) {
assistantMessage.content = 'Error: No model available. Please launch an instance first.';
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = 'Error: No model available. Please launch an instance first.';
}
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -980,7 +1104,10 @@ class AppStore {

if (!response.ok) {
const errorText = await response.text();
assistantMessage.content = `Error: ${response.status} - ${errorText}`;
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = `Error: ${response.status} - ${errorText}`;
}
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -988,7 +1115,10 @@ class AppStore {

const reader = response.body?.getReader();
if (!reader) {
assistantMessage.content = 'Error: No response stream available';
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = 'Error: No response stream available';
}
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -1016,9 +1146,16 @@ class AppStore {
const delta = json.choices?.[0]?.delta?.content;
if (delta) {
fullContent += delta;
const { displayContent } = this.stripThinkingTags(fullContent);
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
this.currentResponse = displayContent;
assistantMessage.content = displayContent;

// Update the assistant message in place (triggers Svelte reactivity)
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
}
this.persistActiveConversation();
}
} catch {
// Skip malformed JSON
@@ -1027,16 +1164,25 @@ class AppStore {
}
}

const { displayContent } = this.stripThinkingTags(fullContent);
assistantMessage.content = displayContent;
this.currentResponse = '';
this.updateActiveConversation();
// Final cleanup of the message
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
}
this.persistActiveConversation();

} catch (error) {
assistantMessage.content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
this.updateActiveConversation();
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
}
this.persistActiveConversation();
} finally {
this.isLoading = false;
this.currentResponse = '';
this.updateActiveConversation();
}
}
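The recurring change in this streaming loop is that every write goes through `this.messages[idx]` rather than a captured `assistantMessage` alias, because Svelte 5 tracks mutations on the `$state` array's elements, not on stale references to them. The accumulation itself is ordinary OpenAI-style chunk handling; a reduced sketch (field paths follow the `choices[0].delta.content` convention used above):

```ts
interface Message { id: string; content: string }

function applyDelta(messages: Message[], id: string, chunkJson: string): void {
  const json = JSON.parse(chunkJson) as { choices?: Array<{ delta?: { content?: string } }> };
  const delta = json.choices?.[0]?.delta?.content;
  if (!delta) return;
  const idx = messages.findIndex(m => m.id === id);
  if (idx !== -1) {
    messages[idx].content += delta; // mutate the tracked element, not a stale alias
  }
}
```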
@@ -1396,6 +1542,8 @@ export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();

// Actions
export const startChat = () => appStore.startChat();
@@ -1423,5 +1571,9 @@ export const isSidebarOpen = () => appStore.isSidebarOpen;
export const toggleSidebar = () => appStore.toggleSidebar();
export const toggleDebugMode = () => appStore.toggleDebugMode();
export const setDebugMode = (enabled: boolean) => appStore.setDebugMode(enabled);
export const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode();
export const setTopologyOnlyMode = (enabled: boolean) => appStore.setTopologyOnlyMode(enabled);
export const toggleChatSidebarVisible = () => appStore.toggleChatSidebarVisible();
export const setChatSidebarVisible = (visible: boolean) => appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();

@@ -1,7 +1,25 @@
<script lang="ts">
import '../app.css';
import { onMount } from 'svelte';
import { browser } from '$app/environment';

let { children } = $props();
let isPageHidden = $state(false);

onMount(() => {
if (!browser) return;

// Listen for visibility changes to pause animations when hidden
const handleVisibilityChange = () => {
isPageHidden = document.visibilityState === 'hidden';
};

document.addEventListener('visibilitychange', handleVisibilityChange);

return () => {
document.removeEventListener('visibilitychange', handleVisibilityChange);
};
});
</script>

<svelte:head>
@@ -9,7 +27,7 @@
<meta name="description" content="EXO - Distributed AI Cluster Dashboard" />
</svelte:head>

<div class="min-h-screen bg-background text-foreground">
<div class="min-h-screen bg-background text-foreground" data-page-hidden={isPageHidden}>
{@render children?.()}
</div>

@@ -18,6 +18,10 @@
selectedChatModel,
debugMode,
toggleDebugMode,
topologyOnlyMode,
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
type DownloadProgress,
type PlacementPreview
} from '$lib/stores/app.svelte';
@@ -37,6 +41,8 @@
const selectedModelId = $derived(selectedPreviewModelId());
const loadingPreviews = $derived(isLoadingPreviews());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());

let mounted = $state(false);

@@ -93,17 +99,35 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}

// Compute highlighted nodes from hovered instance or hovered preview
// Memoized to avoid creating new Sets on every render
let lastHighlightedNodesKey = '';
let cachedHighlightedNodes: Set<string> = new Set();

const highlightedNodes = $derived(() => {
// Create a key for the current state to enable memoization
const previewKey = Array.from(hoveredPreviewNodes).sort().join(',');
const currentKey = `${hoveredInstanceId || 'null'}:${previewKey}`;

// Return cached value if nothing changed
if (currentKey === lastHighlightedNodesKey) {
return cachedHighlightedNodes;
}

lastHighlightedNodesKey = currentKey;

// First check instance hover
if (hoveredInstanceId) {
const instanceWrapped = instanceData[hoveredInstanceId];
return unwrapInstanceNodes(instanceWrapped);
cachedHighlightedNodes = unwrapInstanceNodes(instanceWrapped);
return cachedHighlightedNodes;
}
// Then check preview hover
if (hoveredPreviewNodes.size > 0) {
return hoveredPreviewNodes;
cachedHighlightedNodes = hoveredPreviewNodes;
return cachedHighlightedNodes;
}
return new Set<string>();
cachedHighlightedNodes = new Set<string>();
return cachedHighlightedNodes;
});
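The derived value above memoizes on a cheap composite key so hovering does not mint a new `Set` on every re-run; identical inputs return the identical object, which lets downstream equality checks short-circuit. The bare pattern, with the instance branch simplified to a placeholder (a sketch, not the page's actual lookup):

```ts
let lastKey = '';
let cached: Set<string> = new Set();

function highlightFor(instanceId: string | null, previewIds: Set<string>): Set<string> {
  const key = `${instanceId ?? 'null'}:${Array.from(previewIds).sort().join(',')}`;
  if (key === lastKey) return cached; // same inputs -> same Set identity
  lastKey = key;
  cached = instanceId ? new Set([instanceId]) : new Set(previewIds);
  return cached;
}
```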
// Helper to estimate memory from model ID (mirrors ModelCard logic)
@@ -472,6 +496,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

const progress = parseDownloadProgress(downloadPayload);
if (progress) {
// Sum all values across nodes - each node downloads independently
totalBytes += progress.totalBytes;
downloadedBytes += progress.downloadedBytes;
totalSpeed += progress.speed;
@@ -489,13 +514,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
return { isDownloading: false, progress: null, perNode: [] };
}

// ETA = total remaining bytes / total speed across all nodes
const remainingBytes = totalBytes - downloadedBytes;
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;

return {
isDownloading: true,
progress: {
totalBytes,
downloadedBytes,
speed: totalSpeed,
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
etaMs,
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
completedFiles,
totalFiles,
@@ -505,12 +534,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
};
}
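The ETA now comes from one aggregate division rather than an inline recomputation: remaining bytes over combined speed, scaled to milliseconds. Concretely, with 4 × 10⁹ bytes remaining at 5 × 10⁷ bytes/s, etaMs = (4e9 / 5e7) × 1000 = 80,000 ms, i.e. 1 min 20 s. As a tiny function:

```ts
function etaMillis(totalBytes: number, downloadedBytes: number, bytesPerSec: number): number {
  const remaining = totalBytes - downloadedBytes;
  return bytesPerSec > 0 ? (remaining / bytesPerSec) * 1000 : 0; // 0 doubles as "unknown"
}

etaMillis(10e9, 6e9, 50e6); // 80_000 ms
```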
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log('[Download Debug] Current downloads:', downloadsData);
}
});
// Debug: Log downloads data when it changes (disabled in production for performance)
// Uncomment for debugging:
// $effect(() => {
// if (downloadsData && Object.keys(downloadsData).length > 0) {
// console.log('[Download Debug] Current downloads:', downloadsData);
// }
// });

// Helper to get download status for an instance
function getInstanceDownloadStatus(instanceId: string, instanceWrapped: unknown): {
@@ -576,6 +606,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

const progress = parseDownloadProgress(downloadPayload);
if (progress) {
// Sum all values across nodes - each node downloads independently
totalBytes += progress.totalBytes;
downloadedBytes += progress.downloadedBytes;
totalSpeed += progress.speed;
@@ -596,13 +627,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
return { isDownloading: false, progress: null, statusText: statusInfo.statusText, perNode: [] };
}

// ETA = total remaining bytes / total speed across all nodes
const remainingBytes = totalBytes - downloadedBytes;
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;

return {
isDownloading: true,
progress: {
totalBytes,
downloadedBytes,
speed: totalSpeed,
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
etaMs,
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
completedFiles,
totalFiles,
@@ -618,10 +653,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function getStatusColor(statusText: string): string {
switch (statusText) {
case 'FAILED': return 'text-red-400';
case 'SHUTDOWN': return 'text-gray-400';
case 'DOWNLOADING': return 'text-blue-400';
case 'LOADING':
case 'WARMING UP':
case 'WAITING': return 'text-yellow-400';
case 'WAITING':
case 'INITIALIZING': return 'text-yellow-400';
case 'RUNNING': return 'text-teal-400';
case 'READY':
case 'LOADED': return 'text-green-400';
@@ -644,12 +681,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!r) return null;
const [kind] = getTagged(r);
const statusMap: Record<string, string> = {
RunnerWaitingForInitialization: 'WaitingForInitialization',
RunnerInitializingBackend: 'InitializingBackend',
RunnerWaitingForModel: 'WaitingForModel',
RunnerLoading: 'Loading',
RunnerLoaded: 'Loaded',
RunnerWarmingUp: 'WarmingUp',
RunnerReady: 'Ready',
RunnerRunning: 'Running',
RunnerShutdown: 'Shutdown',
RunnerFailed: 'Failed',
};
return kind ? statusMap[kind] || null : null;
@@ -660,12 +700,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
if (has('WarmingUp')) return { statusText: 'WARMING UP', statusClass: 'starting' };
if (has('Running')) return { statusText: 'RUNNING', statusClass: 'running' };
if (has('Ready')) return { statusText: 'READY', statusClass: 'loaded' };
if (has('Loaded')) return { statusText: 'LOADED', statusClass: 'loaded' };
if (has('WaitingForModel')) return { statusText: 'WAITING', statusClass: 'starting' };
if (has('InitializingBackend')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
if (has('WaitingForInitialization')) return { statusText: 'INITIALIZING', statusClass: 'starting' };

return { statusText: 'RUNNING', statusClass: 'active' };
}
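Aggregating per-runner states into one badge is a precedence scan: the first matching state in a fixed severity order wins, so a single failed runner marks the whole instance FAILED even if the others are READY. The same logic, table-driven (a sketch; the state names mirror the ones above):

```ts
const PRECEDENCE: Array<[state: string, text: string, cls: string]> = [
  ['Failed', 'FAILED', 'failed'],
  ['Shutdown', 'SHUTDOWN', 'inactive'],
  ['Loading', 'LOADING', 'starting'],
  ['WarmingUp', 'WARMING UP', 'starting'],
  ['Running', 'RUNNING', 'running'],
  ['Ready', 'READY', 'loaded'],
  ['Loaded', 'LOADED', 'loaded'],
  ['WaitingForModel', 'WAITING', 'starting'],
  ['InitializingBackend', 'INITIALIZING', 'starting'],
  ['WaitingForInitialization', 'INITIALIZING', 'starting'],
];

function summarize(statuses: string[]): { statusText: string; statusClass: string } {
  if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
  const hit = PRECEDENCE.find(([state]) => statuses.includes(state));
  return hit
    ? { statusText: hit[1], statusClass: hit[2] }
    : { statusText: 'RUNNING', statusClass: 'active' }; // fallback for unmapped states
}
```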
@@ -1107,16 +1150,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="shooting-star" style="top: 50%; left: 40%; --duration: 45s; --delay: 30s;"></div>
</div>

<HeaderNav showHome={chatStarted} onHome={handleGoHome} />
{#if !topologyOnlyEnabled}
<HeaderNav
showHome={chatStarted}
onHome={handleGoHome}
showSidebarToggle={true}
sidebarVisible={sidebarVisible}
onToggleSidebar={toggleChatSidebarVisible}
/>
{/if}

<!-- Main Content -->
<main class="flex-1 flex overflow-hidden relative">
<!-- Left: Conversation History Sidebar (always visible) -->
<!-- Left: Conversation History Sidebar (hidden in topology-only mode or when toggled off) -->
{#if !topologyOnlyEnabled && sidebarVisible}
<div class="w-80 flex-shrink-0 border-r border-exo-yellow/10">
<ChatSidebar class="h-full" />
</div>
{/if}

{#if !chatStarted}
{#if topologyOnlyEnabled}
<!-- TOPOLOGY ONLY MODE: Full-screen topology -->
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
<!-- Exit topology-only mode button -->
<button
type="button"
onclick={toggleTopologyOnlyMode}
class="absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm"
title="Exit topology only mode"
>
<svg class="w-5 h-5 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="5" r="2" fill="currentColor" />
<circle cx="5" cy="19" r="2" fill="currentColor" />
<circle cx="19" cy="19" r="2" fill="currentColor" />
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
</svg>
</button>
</div>
</div>
{:else if !chatStarted}
<!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->
<div class="flex-1 flex overflow-visible relative" in:fade={{ duration: 300 }} out:fade={{ duration: 200 }}>

@@ -1300,14 +1374,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{:else}
{#each nodeProg.progress.files as f}
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
{@const isFileComplete = filePercent >= 100}
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
<span class="truncate pr-2">{f.name}</span>
<span class="text-white/80">{filePercent.toFixed(1)}%</span>
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
style="width: {filePercent.toFixed(1)}%"
></div>
</div>
@@ -1611,13 +1686,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
in:fade={{ duration: 300, delay: 100 }}
>
<div class="flex-1 overflow-y-auto px-8 py-6" bind:this={chatScrollRef}>
<div class="max-w-3xl mx-auto">
<div class="max-w-7xl mx-auto">
<ChatMessages scrollParent={chatScrollRef} />
</div>
</div>

<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-3xl mx-auto">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
</div>
</div>
@@ -1655,7 +1730,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<!-- Panel Header -->
<div class="flex items-center gap-2 mb-4">
<div class="w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse"></div>
<h3 class="text-sm text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
@@ -1701,28 +1776,28 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex justify-between items-start mb-2 pl-2">
<div class="flex items-center gap-2">
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-xs tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400/80 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-sm font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[10px] text-white/60 hover:text-exo-yellow transition-colors mt-0.5"
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
href={`https://huggingface.co/${instanceModelId}`}
target="_blank"
rel="noreferrer noopener"
aria-label="View model on Hugging Face"
>
<span>Hugging Face</span>
<svg class="w-3 h-3" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M14 3h7v7"/>
<path d="M10 14l11-11"/>
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
@@ -1733,68 +1808,84 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-white/60 text-xs font-mono">{instanceInfo.nodeNames.join(', ')}</div>
{/if}
{#if debugEnabled && instanceConnections.length > 0}
<div class="mt-1 space-y-0.5">
{#each instanceConnections as conn}
<div class="text-[10px] leading-snug font-mono text-white/70">
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
</div>
{/each}
<div class="mt-2 space-y-1">
{#each instanceConnections as conn}
<div class="text-[11px] leading-snug font-mono text-white/70">
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
</div>
{/each}
</div>
{/if}

<!-- Download Progress -->
{#if downloadInfo.isDownloading && downloadInfo.progress}
<div class="mt-2 space-y-1">
<div class="flex justify-between text-xs font-mono">
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
</div>
{/if}

<!-- Download Progress -->
{#if downloadInfo.isDownloading && downloadInfo.progress}
<div class="mt-2 space-y-1">
<div class="flex justify-between text-sm font-mono">
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {downloadInfo.progress.percentage}%"
></div>
</div>
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
</div>
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {downloadInfo.progress.percentage}%"
></div>
</div>
{#if downloadInfo.perNode.length > 0}
<div class="mt-2 space-y-1.5 max-h-48 overflow-y-auto pr-1">
{#each downloadInfo.perNode as nodeProg}
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
</div>
</div>
{#if downloadInfo.perNode.length > 0}
<div class="mt-2 space-y-2 max-h-48 overflow-y-auto pr-1">
{#each downloadInfo.perNode as nodeProg}
{@const nodePercent = Math.min(100, Math.max(0, nodeProg.progress.percentage))}
{@const isExpanded = instanceDownloadExpandedNodes.has(nodeProg.nodeId)}
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
<button
type="button"
class="w-full text-left space-y-1.5"
onclick={() => toggleInstanceDownloadDetails(nodeProg.nodeId)}
>
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
<span class="text-white/80 truncate pr-2">{nodeProg.nodeName}</span>
<span class="text-blue-300">{Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%</span>
<span class="flex items-center gap-1 text-blue-300">
{nodePercent.toFixed(1)}%
<svg class="w-3 h-3 text-exo-light-gray" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
<path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
</svg>
</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mb-1.5">
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-blue-500/80 transition-all duration-300"
style="width: {Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%"
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {nodePercent.toFixed(1)}%"
></div>
</div>
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
<span>{formatBytes(nodeProg.progress.downloadedBytes)} / {formatBytes(nodeProg.progress.totalBytes)}</span>
<span>{formatSpeed(nodeProg.progress.speed)} • ETA {formatEta(nodeProg.progress.etaMs)}</span>
</div>
{#if nodeProg.progress.files.length > 0}
{@const inProgressFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) < 100)}
|
||||
{@const completedFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) >= 100)}
|
||||
{#if inProgressFiles.length > 0}
|
||||
<div class="space-y-1">
|
||||
{#each inProgressFiles as f}
|
||||
<div class="text-[10px] font-mono text-exo-light-gray/80">
|
||||
<div class="flex items-center justify-between">
|
||||
</button>
|
||||
|
||||
{#if isExpanded}
|
||||
<div class="mt-2 space-y-1.5">
|
||||
{#if nodeProg.progress.files.length === 0}
|
||||
<div class="text-[11px] font-mono text-exo-light-gray/70">No file details reported.</div>
|
||||
{:else}
|
||||
{#each nodeProg.progress.files as f}
|
||||
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
|
||||
{@const isFileComplete = filePercent >= 100}
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
|
||||
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/70">{Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%</span>
|
||||
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/50 rounded-sm overflow-hidden mt-0.5">
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70"
|
||||
style="width: {Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
|
||||
style="width: {filePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5">
|
||||
@@ -1803,27 +1894,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
</div>
{/each}
</div>
{/if}
{#if completedFiles.length > 0}
<div class="pt-1 space-y-0.5">
{#each completedFiles as f}
<div class="text-[10px] font-mono text-exo-light-gray/70 flex items-center justify-between">
<span class="truncate pr-2">{f.name}</span>
<span class="text-white/60">100%</span>
</div>
{/each}
</div>
{/if}
{/if}
</div>
{/if}
</div>
{/each}
</div>
{/if}
<div class="text-sm text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
{:else}
<div class="text-sm {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
{:else}
<div class="text-xs {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
{/if}
</div>
</div>
</div>

@@ -345,13 +345,19 @@
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
<div class="flex items-center justify-between gap-3">
<div class="min-w-0 space-y-0.5">
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
<div class="text-[11px] text-exo-light-gray font-mono truncate">
{model.modelId}
</div>
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
<div
class="text-xs font-mono text-white truncate"
title={model.prettyName ?? model.modelId}
>{model.prettyName ?? model.modelId}</div>
<div
class="text-[10px] text-exo-light-gray font-mono truncate"
title={model.modelId}
>{model.modelId}</div>
{#if model.status !== 'completed'}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
{/if}
</div>
<div class="flex items-center gap-2">
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
@@ -426,14 +432,14 @@
<style>
.downloads-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
}
@media (min-width: 1024px) {
.downloads-grid {
grid-template-columns: repeat(3, minmax(0, 1fr));
}
}
@media (min-width: 1440px) {
@media (min-width: 1600px) {
.downloads-grid {
grid-template-columns: repeat(4, minmax(0, 1fr));
}

@@ -29,10 +29,12 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
]

[project.scripts]
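The platform split above relies on PEP 508 environment markers, which pip and uv evaluate at install time so macOS gets the Metal build of MLX while Linux gets the CPU extra. A minimal sketch of how such markers resolve on the running interpreter (the `packaging` library and the printout are illustrative, not part of this diff):

```python
# Evaluate the PEP 508 markers used above against the current interpreter
# (assumes the `packaging` library is installed; output differs per OS).
from packaging.markers import Marker

for spec in ("sys_platform == 'darwin'", "sys_platform == 'linux'"):
    # Marker.evaluate() checks the marker against this environment,
    # so exactly one of these prints True on macOS vs Linux.
    print(spec, "->", Marker(spec).evaluate())
```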
@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
@@ -21,11 +27,13 @@ from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -56,7 +64,7 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
HIDE_THINKING = False
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
@@ -161,7 +169,9 @@ class API:
        self.app.delete("/instance/{instance_id}")(self.delete_instance)
        self.app.get("/models")(self.get_models)
        self.app.get("/v1/models")(self.get_models)
        self.app.post("/v1/chat/completions")(self.chat_completions)
        self.app.post("/v1/chat/completions", response_model=None)(
            self.chat_completions
        )
        self.app.get("/state")(lambda: self.state)
        self.app.get("/events")(lambda: self._event_log)
@@ -177,17 +187,32 @@ class API:
        return CreateInstanceResponse(
            message="Command received.",
            command_id=command.command_id,
            model_meta=command.model_meta,
        )

    async def create_instance(
        self, payload: CreateInstanceParams
    ) -> CreateInstanceResponse:
        command = CreateInstance(instance=payload.instance)
        instance = payload.instance
        model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
        required_memory = model_meta.storage_size
        available_memory = self._calculate_total_available_memory()

        if required_memory > available_memory:
            raise HTTPException(
                status_code=400,
                detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
            )

        command = CreateInstance(
            instance=instance,
        )
        await self._send(command)

        return CreateInstanceResponse(
            message="Command received.",
            command_id=command.command_id,
            model_meta=model_meta,
        )

    async def get_placement(
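The new `create_instance` guard rejects a placement before any command is dispatched when the model cannot fit in aggregate cluster memory. A standalone sketch of the same pre-flight check (plain floats stand in for exo's `Memory` type, and `ValueError` stands in for the HTTP 400 path):

```python
# Standalone sketch of the pre-flight memory guard above. Plain floats
# stand in for exo's Memory type; ValueError mirrors HTTPException(400).
def check_memory(required_gb: float, available_gb: float) -> None:
    if required_gb > available_gb:
        raise ValueError(
            f"Insufficient memory to create instance. "
            f"Required: {required_gb:.1f}GB, Available: {available_gb:.1f}GB"
        )

check_memory(12.0, 64.0)   # fits across the cluster: returns silently
# check_memory(96.0, 64.0) # would raise, mirroring the 400 response
```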
@@ -207,7 +232,6 @@ class API:
                instance_meta=instance_meta,
                min_nodes=min_nodes,
            ),
            node_profiles=self.state.node_profiles,
            topology=self.state.topology,
            current_instances=self.state.instances,
        )
@@ -263,7 +287,6 @@ class API:
                instance_meta=instance_meta,
                min_nodes=min_nodes,
            ),
            node_profiles=self.state.node_profiles,
            topology=self.state.topology,
            current_instances=self.state.instances,
        )
@@ -354,32 +377,52 @@ class API:
            instance_id=instance_id,
        )

    async def _generate_chat_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[str, None]:
        """Generate chat completion stream as JSON strings."""
    async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
        stream = StreamableParser(encoding, role=Role.ASSISTANT)
        thinking = False

        async for chunk in token_chunks:
            stream.process(chunk.token_id)

            delta = stream.last_content_delta
            ch = stream.current_channel

            if ch == "analysis" and not thinking:
                thinking = True
                yield chunk.model_copy(update={"text": "<think>"})

            if ch != "analysis" and thinking:
                thinking = False
                yield chunk.model_copy(update={"text": "</think>"})

            if delta:
                yield chunk.model_copy(update={"text": delta})

            if chunk.finish_reason is not None:
                if thinking:
                    yield chunk.model_copy(update={"text": "</think>"})
                yield chunk
                break

    async def _chat_chunk_stream(
        self, command_id: CommandId, parse_gpt_oss: bool
    ) -> AsyncGenerator[TokenChunk, None]:
        """Yield `TokenChunk`s for a given command until completion."""

        try:
            self._chat_completion_queues[command_id], recv = channel[TokenChunk]()

            is_thinking = False
            with recv as token_chunks:
                async for chunk in token_chunks:
                    if HIDE_THINKING:
                        if chunk.text == "<think>":
                            is_thinking = True
                        if chunk.text == "</think>":
                            is_thinking = False
                    chunk_response: ChatCompletionResponse = chunk_to_response(
                        chunk, command_id
                    )
                    if not (is_thinking and HIDE_THINKING):
                        logger.debug(f"chunk_response: {chunk_response}")
                        yield f"data: {chunk_response.model_dump_json()}\n\n"

                    if chunk.finish_reason is not None:
                        yield "data: [DONE]\n\n"
                        break
                if parse_gpt_oss:
                    async for chunk in self._process_gpt_oss(token_chunks):
                        yield chunk
                        if chunk.finish_reason is not None:
                            break
                else:
                    async for chunk in token_chunks:
                        yield chunk
                        if chunk.finish_reason is not None:
                            break

        except anyio.get_cancelled_exc_class():
            # TODO: TaskCancelled
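The `_process_gpt_oss` generator above brackets tokens from the harmony "analysis" channel in `<think>` tags and closes the tag if the stream ends mid-analysis. A synchronous sketch of the same state machine (the `(channel, delta)` pairs stand in for parsed `StreamableParser` output; the values are illustrative):

```python
# Synchronous sketch of the <think>-wrapping logic above. Each (channel,
# delta) pair stands in for one parsed token; "analysis" marks reasoning.
def wrap_thinking(parsed):
    thinking = False
    for channel, delta in parsed:
        if channel == "analysis" and not thinking:
            thinking = True
            yield "<think>"
        if channel != "analysis" and thinking:
            thinking = False
            yield "</think>"
        if delta:
            yield delta
    if thinking:  # close the tag if the stream ends mid-analysis
        yield "</think>"

print("".join(wrap_thinking(
    [("analysis", "plan"), ("final", "Hello"), ("final", "!")]
)))  # -> <think>plan</think>Hello!
```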
@@ -394,6 +437,59 @@ class API:
            await self._send(command)
            del self._chat_completion_queues[command_id]

    async def _generate_chat_stream(
        self, command_id: CommandId, parse_gpt_oss: bool
    ) -> AsyncGenerator[str, None]:
        """Generate chat completion stream as JSON strings."""

        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
            chunk_response: ChatCompletionResponse = chunk_to_response(
                chunk, command_id
            )
            logger.debug(f"chunk_response: {chunk_response}")

            yield f"data: {chunk_response.model_dump_json()}\n\n"

            if chunk.finish_reason is not None:
                yield "data: [DONE]\n\n"

    async def _collect_chat_completion(
        self, command_id: CommandId, parse_gpt_oss: bool
    ) -> ChatCompletionResponse:
        """Collect all token chunks for a chat completion and return a single response."""

        text_parts: list[str] = []
        model: str | None = None
        finish_reason: FinishReason | None = None

        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
            if model is None:
                model = chunk.model

            text_parts.append(chunk.text)

            if chunk.finish_reason is not None:
                finish_reason = chunk.finish_reason

        combined_text = "".join(text_parts)
        assert model is not None

        return ChatCompletionResponse(
            id=command_id,
            created=int(time.time()),
            model=model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant",
                        content=combined_text,
                    ),
                    finish_reason=finish_reason,
                )
            ],
        )

    async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
        logger.warning(
            "TODO: we should send a notification to the user to download the model"
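Because `_chat_chunk_stream` now backs both `_generate_chat_stream` and `_collect_chat_completion`, a single endpoint serves streaming and non-streaming clients. A client-side sketch (port 52415 is exo's default API port; the model name is illustrative):

```python
# Client-side sketch: the same /v1/chat/completions endpoint serves both
# modes depending on the "stream" flag. Model name is illustrative.
import json
import urllib.request

def chat(stream: bool) -> None:
    req = urllib.request.Request(
        "http://localhost:52415/v1/chat/completions",
        data=json.dumps({
            "model": "llama-3.2-1b",
            "messages": [{"role": "user", "content": "hi"}],
            "stream": stream,
        }).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        if stream:
            for line in resp:  # SSE lines: b"data: {...}" then b"data: [DONE]"
                print(line.decode().rstrip())
        else:
            body = json.load(resp)  # one ChatCompletionResponse object
            print(body["choices"][0]["message"]["content"])
```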
@@ -401,10 +497,12 @@ class API:

    async def chat_completions(
        self, payload: ChatCompletionTaskParams
    ) -> StreamingResponse:
        """Handle chat completions with proper streaming response."""
    ) -> ChatCompletionResponse | StreamingResponse:
        """Handle chat completions, supporting both streaming and non-streaming responses."""
        model_meta = await resolve_model_meta(payload.model)
        payload.model = model_meta.model_id
        parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
        logger.info(f"{parse_gpt_oss=}")

        if not any(
            instance.shard_assignments.model_id == payload.model
@@ -419,17 +517,21 @@ class API:
            request_params=payload,
        )
        await self._send(command)
        return StreamingResponse(
            self._generate_chat_stream(command.command_id),
            media_type="text/event-stream",
        )
        if payload.stream:
            return StreamingResponse(
                self._generate_chat_stream(command.command_id, parse_gpt_oss),
                media_type="text/event-stream",
            )

        return await self._collect_chat_completion(command.command_id, parse_gpt_oss)

    def _calculate_total_available_memory(self) -> Memory:
        """Calculate total available memory across all nodes in bytes."""
        total_available = Memory()

        for profile in self.state.node_profiles.values():
            total_available += profile.memory.ram_available
        for node in self.state.topology.list_nodes():
            if node.node_profile is not None:
                total_available += node.node_profile.memory.ram_available

        return total_available
@@ -443,6 +545,8 @@ class API:
                name=card.name,
                description=card.description,
                tags=card.tags,
                storage_size_megabytes=int(card.metadata.storage_size.in_mb),
                supports_tensor=card.metadata.supports_tensor,
            )
            for card in MODEL_CARDS.values()
        ]
@@ -459,7 +563,7 @@ class API:
        async with create_task_group() as tg:
            self._tg = tg
            logger.info("Starting API")
            tg.start_soon(self._applystate)
            tg.start_soon(self._apply_state)
            tg.start_soon(self._pause_on_new_election)
            print_startup_banner(self.port)
            await serve(
@@ -471,7 +575,7 @@ class API:
        self.command_sender.close()
        self.global_event_receiver.close()

    async def _applystate(self):
    async def _apply_state(self):
        with self.global_event_receiver as events:
            async for f_event in events:
                if f_event.origin != self.session_id.master_node_id:
@@ -158,7 +158,6 @@ class Master:
            command,
            self.state.topology,
            self.state.instances,
            self.state.node_profiles,
        )
        transition_events = get_transition_events(
            self.state.instances, placement
@@ -201,7 +200,9 @@ class Master:
    async def _plan(self) -> None:
        while True:
            # kill broken instances
            connected_node_ids = set([x for x in self.state.topology.list_nodes()])
            connected_node_ids = set(
                [x.node_id for x in self.state.topology.list_nodes()]
            )
            for instance_id, instance in self.state.instances.items():
                for node_id in instance.shard_assignments.node_to_runner:
                    if node_id not in connected_node_ids:
@@ -6,11 +6,10 @@ from typing import Sequence
from loguru import logger

from exo.master.placement_utils import (
    NodeWithProfile,
    filter_cycles_by_memory,
    get_hosts_from_subgraph,
    get_mlx_ibv_devices_matrix,
    get_mlx_jaccl_coordinators,
    get_mlx_jaccl_devices_matrix,
    get_mlx_ring_hosts_by_node,
    get_shard_assignments,
    get_smallest_cycles,
)
@@ -20,10 +19,9 @@ from exo.shared.types.commands import (
    DeleteInstance,
    PlaceInstance,
)
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
    Instance,
    InstanceId,
@@ -52,16 +50,19 @@ def place_instance(
    command: PlaceInstance,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
    all_nodes = list(topology.list_nodes())

    cycles = topology.get_cycles() + [[node] for node in all_nodes]
    candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
    cycles_with_sufficient_memory = filter_cycles_by_memory(
        candidate_cycles, node_profiles, command.model_meta.storage_size
    logger.info("finding cycles:")
    cycles = topology.get_cycles()
    singleton_cycles = [[node] for node in all_nodes]
    candidate_cycles = list(
        filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
    )
    if len(cycles_with_sufficient_memory) == 0:
    cycles_with_sufficient_memory = filter_cycles_by_memory(
        candidate_cycles, command.model_meta.storage_size
    )
    if not cycles_with_sufficient_memory:
        raise ValueError("No cycles found with sufficient memory")

    smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
@@ -69,15 +70,13 @@ def place_instance(
    smallest_tb_cycles = [
        cycle
        for cycle in smallest_cycles
        if topology.get_subgraph_from_nodes(
            [node.node_id for node in cycle]
        ).is_thunderbolt_cycle([node.node_id for node in cycle])
        if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
    ]

    if smallest_tb_cycles != []:
        smallest_cycles = smallest_tb_cycles

    cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
    cycles_with_leaf_nodes: list[list[NodeInfo]] = [
        cycle
        for cycle in smallest_cycles
        if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -86,7 +85,11 @@
    selected_cycle = max(
        cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
        key=lambda cycle: sum(
            (node.node_profile.memory.ram_available for node in cycle),
            (
                node.node_profile.memory.ram_available
                for node in cycle
                if node.node_profile is not None
            ),
            start=Memory(),
        ),
    )
@@ -95,16 +98,14 @@
        command.model_meta, selected_cycle, command.sharding
    )

    cycle_digraph: Topology = topology.get_subgraph_from_nodes(
        [node.node_id for node in selected_cycle]
    )
    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)

    instance_id = InstanceId()
    target_instances = dict(deepcopy(current_instances))

    if len(selected_cycle) == 1:
        logger.warning(
            "You have likely selected jaccl for a single node instance; falling back to MlxRing"
            "You have likely selected ibv for a single node instance; falling back to MlxRing"
        )

        command.instance_meta = InstanceMeta.MlxRing
@@ -112,32 +113,33 @@
    # TODO: Single node instances
    match command.instance_meta:
        case InstanceMeta.MlxJaccl:
            mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
            mlx_ibv_devices = get_mlx_ibv_devices_matrix(
                selected_cycle,
                cycle_digraph,
            )
            mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
                coordinator=selected_cycle[0].node_id,
                selected_cycle,
                coordinator_port=random_ephemeral_port(),
                cycle_digraph=cycle_digraph,
            )
            target_instances[instance_id] = MlxJacclInstance(
                instance_id=instance_id,
                shard_assignments=shard_assignments,
                jaccl_devices=mlx_jaccl_devices,
                ibv_devices=mlx_ibv_devices,
                jaccl_coordinators=mlx_jaccl_coordinators,
            )
        case InstanceMeta.MlxRing:
            hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
            ephemeral_port = random_ephemeral_port()
            hosts_by_node = get_mlx_ring_hosts_by_node(
                selected_cycle=selected_cycle,
                cycle_digraph=cycle_digraph,
                ephemeral_port=ephemeral_port,
            )
            target_instances[instance_id] = MlxRingInstance(
                instance_id=instance_id,
                shard_assignments=shard_assignments,
                hosts=[
                    Host(
                        ip=host.ip,
                        port=random_ephemeral_port(),
                    )
                    for host in hosts
                ],
                hosts_by_node=hosts_by_node,
                ephemeral_port=ephemeral_port,
            )

    return target_instances
@@ -1,4 +1,5 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast

from loguru import logger
from pydantic import BaseModel
@@ -8,7 +9,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
@@ -23,32 +24,27 @@ class NodeWithProfile(BaseModel):
    node_profile: NodePerformanceProfile


def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
    return all(node.node_profile is not None for node in nodes)


def filter_cycles_by_memory(
    cycles: list[list[NodeId]],
    node_profiles: Mapping[NodeId, NodePerformanceProfile],
    required_memory: Memory,
) -> list[list[NodeWithProfile]]:
    filtered_cycles: list[list[NodeWithProfile]] = []
    cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
    filtered_cycles: list[list[NodeInfo]] = []
    for cycle in cycles:
        if not all(node in node_profiles for node in cycle):
        if not narrow_all_nodes(cycle):
            continue

        total_mem = sum(
            (node_profiles[node].memory.ram_available for node in cycle), start=Memory()
            (node.node_profile.memory.ram_available for node in cycle), start=Memory()
        )
        if total_mem >= required_memory:
            filtered_cycles.append(
                [
                    NodeWithProfile(node_id=node, node_profile=node_profiles[node])
                    for node in cycle
                ]
            )
            filtered_cycles.append(cast(list[NodeInfo], cycle))
    return filtered_cycles


def get_smallest_cycles(
    cycles: list[list[NodeWithProfile]],
) -> list[list[NodeWithProfile]]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
    min_nodes = min(len(cycle) for cycle in cycles)
    return [cycle for cycle in cycles if len(cycle) == min_nodes]
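`narrow_all_nodes` above uses `typing.TypeGuard` so that, after the runtime check succeeds, the type checker may treat the same list as fully profiled without per-element `None` checks. A self-contained illustration of the pattern with toy types (not exo's real models):

```python
# Self-contained illustration of the TypeGuard pattern used by
# narrow_all_nodes (toy types, not exo's real models). Python 3.10+.
from dataclasses import dataclass
from typing import TypeGuard

@dataclass
class Node:
    profile: str | None

@dataclass
class ProfiledNode(Node):
    profile: str  # same field, but guaranteed non-None

def all_profiled(nodes: list[Node]) -> TypeGuard[list[ProfiledNode]]:
    # Runtime check; on True, the checker narrows the element type.
    return all(n.profile is not None for n in nodes)

nodes: list[Node] = [Node("a"), Node("b")]
if all_profiled(nodes):
    # The checker now knows every profile is a str.
    print([n.profile.upper() for n in nodes])
```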
@@ -139,9 +135,11 @@ def get_shard_assignments_for_tensor_parallel(

def get_shard_assignments(
    model_meta: ModelMetadata,
    selected_cycle: list[NodeWithProfile],
    selected_cycle: list[NodeInfo],
    sharding: Sharding,
) -> ShardAssignments:
    if not narrow_all_nodes(selected_cycle):
        raise ValueError("All nodes must have profiles to create shard assignments")
    match sharding:
        case Sharding.Pipeline:
            return get_shard_assignments_for_pipeline_parallel(
@@ -178,16 +176,17 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
        current_node = cycle[i]
        next_node = cycle[(i + 1) % len(cycle)]

        for src, sink, connection in cycle_digraph.list_connections():
            if not isinstance(connection, SocketConnection):
                continue

            if src == current_node and sink == next_node:
        for connection in cycle_digraph.list_connections():
            if (
                connection.local_node_id == current_node.node_id
                and connection.send_back_node_id == next_node.node_id
            ):
                if get_thunderbolt and not connection.is_thunderbolt():
                    continue
                assert connection.send_back_multiaddr is not None
                host = Host(
                    ip=connection.sink_multiaddr.ip_address,
                    port=connection.sink_multiaddr.port,
                    ip=connection.send_back_multiaddr.ip_address,
                    port=connection.send_back_multiaddr.port,
                )
                hosts.append(host)
                break
@@ -195,7 +194,8 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
    return hosts


def get_mlx_jaccl_devices_matrix(
def get_mlx_ibv_devices_matrix(
    selected_cycle: list[NodeInfo],
    cycle_digraph: Topology,
) -> list[list[str | None]]:
    """Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -204,7 +204,6 @@ def get_mlx_jaccl_devices_matrix(
    to device j, or None if no connection exists or no interface name is found.
    Diagonal elements are always None.
    """
    selected_cycle = list(cycle_digraph.list_nodes())
    num_nodes = len(selected_cycle)
    matrix: list[list[str | None]] = [
        [None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -215,55 +214,191 @@
            if i == j:
                continue

            for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
                if isinstance(conn, RDMAConnection):
                    matrix[i][j] = conn.source_rdma_iface
            # Find the IP J uses to talk to I
            for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
                # This is a local IP on I, which is attached to an interface: find that interface
                if interface_name := _find_rdma_interface_name_for_ip(
                    connection_ip, node_i
                ):
                    matrix[i][j] = interface_name
                    logger.info(
                        f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
                    )
                    break
            else:
                logger.warning(
                    f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
                )
                raise ValueError(
                    "Current jaccl backend requires all-to-all RDMA connections"
                    "Current ibv backend requires all-to-all rdma connections"
                )

    return matrix


def _find_connection_ip(
    node_i: NodeId,
    node_j: NodeId,
    node_i: NodeInfo,
    node_j: NodeInfo,
    cycle_digraph: Topology,
) -> Generator[str]:
    """Find all IP addresses that connect node i to node j."""
    # TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
    for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
        if isinstance(connection, SocketConnection):
            yield connection.sink_multiaddr.ip_address
) -> Generator[tuple[str, bool]]:
    """Find all IP addresses that connect node i to node j, with thunderbolt flag."""
    for connection in cycle_digraph.list_connections():
        if (
            connection.local_node_id == node_i.node_id
            and connection.send_back_node_id == node_j.node_id
        ):
            yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()


def _find_rdma_interface_name_for_ip(
    ip_address: str,
    node_info: NodeInfo,
) -> str | None:
    if node_info.node_profile is None:
        return None

    logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
    for interface in node_info.node_profile.network_interfaces:
        if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
            continue
        logger.info(f" | {interface.name}: {interface.ip_address}")
        if interface.ip_address != ip_address:
            continue

        logger.info("Found")
        return f"rdma_{interface.name}"

    return None


def _find_interface_name_for_ip(
    ip_address: str,
    node_info: NodeInfo,
) -> str | None:
    """Find the interface name for an IP address on a node (any interface)."""
    if node_info.node_profile is None:
        return None

    for interface in node_info.node_profile.network_interfaces:
        if interface.ip_address == ip_address:
            return interface.name

    return None


def _find_ip_prioritised(
    node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
    # TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
    """Find an IP address between nodes with prioritization.

    Priority order:
    1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
    2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
    3. Non-Thunderbolt connections
    4. Any other IP address
    """
    ips = list(_find_connection_ip(node, other_node, cycle_digraph))
    # We expect a unique iface -> ip mapping
    iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}

    en0_ip = iface_map.get("en0")
    if en0_ip:
        return en0_ip

    en1_ip = iface_map.get("en1")
    if en1_ip:
        return en1_ip

    non_thunderbolt_ip = next(
        (ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
    )

    if non_thunderbolt_ip:
        return non_thunderbolt_ip

    if ips:
        return ips[0][0]

    return None


def get_mlx_ring_hosts_by_node(
    selected_cycle: list[NodeInfo],
    cycle_digraph: Topology,
    ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
    """Generate per-node host lists for MLX ring backend.

    Each node gets a list where:
    - Self position: Host(ip="0.0.0.0", port=ephemeral_port)
    - Left/right neighbors: actual connection IPs
    - Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
    """
    world_size = len(selected_cycle)
    if world_size == 0:
        return {}

    hosts_by_node: dict[NodeId, list[Host]] = {}

    for rank, node in enumerate(selected_cycle):
        node_id = node.node_id
        left_rank = (rank - 1) % world_size
        right_rank = (rank + 1) % world_size

        hosts_for_node: list[Host] = []

        for idx, other_node in enumerate(selected_cycle):
            if idx == rank:
                hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
                continue

            if idx not in {left_rank, right_rank}:
                # Placeholder IP from RFC 5737 TEST-NET-2
                hosts_for_node.append(Host(ip="198.51.100.1", port=0))
                continue

            connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
            if connection_ip is None:
                logger.warning(
                    f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
                )
                raise ValueError(
                    "MLX ring backend requires connectivity between neighbouring nodes"
                )

            hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))

        hosts_by_node[node_id] = hosts_for_node

    return hosts_by_node
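For intuition, the per-node host lists that `get_mlx_ring_hosts_by_node` describes can be computed by hand for a small ring. A sketch for a 4-node ring (link-local IPs are illustrative, and `Host` objects are simplified to `(ip, port)` tuples):

```python
# Hand-computed sketch of the per-node host lists for a 4-node ring,
# following the scheme above: self binds 0.0.0.0, ring neighbours get a
# real IP, non-neighbours get the RFC 5737 TEST-NET-2 placeholder.
PLACEHOLDER = ("198.51.100.1", 0)
PORT = 50000

def ring_hosts(world_size, ip_of):
    out = {}
    for rank in range(world_size):
        left, right = (rank - 1) % world_size, (rank + 1) % world_size
        row = []
        for idx in range(world_size):
            if idx == rank:
                row.append(("0.0.0.0", PORT))      # self-binding slot
            elif idx in (left, right):
                row.append((ip_of(idx), PORT))     # real neighbour IP
            else:
                row.append(PLACEHOLDER)            # never dialled
        out[rank] = row
    return out

# Illustrative link-local IPs, one per rank.
print(ring_hosts(4, lambda idx: f"169.254.0.{idx + 1}"))
```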
def get_mlx_jaccl_coordinators(
    coordinator: NodeId,
    selected_cycle: list[NodeInfo],
    coordinator_port: int,
    cycle_digraph: Topology,
) -> dict[NodeId, str]:
    """Get the coordinator addresses for MLX JACCL (rank 0 device).
    """Get the coordinator addresses for MLX Jaccl (rank 0 device).

    Select an IP address that each node can reach for the rank 0 node. Returns
    address in format "X.X.X.X:PORT" per node.
    """
    selected_cycle = list(cycle_digraph.list_nodes())
    logger.info(f"Selecting coordinator: {coordinator}")
    rank_0_node = selected_cycle[0]
    logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")

    def get_ip_for_node(n: NodeId) -> str:
        if n == coordinator:
    def get_ip_for_node(n: NodeInfo) -> str:
        if n.node_id == rank_0_node.node_id:
            return "0.0.0.0"

        for ip in _find_connection_ip(n, coordinator, cycle_digraph):
        for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
            return ip

        logger.warning(
            f"Failed to find directly connected ip between {n} and {coordinator}"
        )
        raise ValueError(
            "Current jaccl backend requires all participating devices to be able to communicate"
            f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
        )
        raise ValueError("Current ibv backend requires all-to-all rdma connections")

    return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
    return {
        n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
    }
@@ -1,36 +1,67 @@
from typing import Callable

import pytest

from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
    MemoryUsage,
    MemoryPerformanceProfile,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo


def create_node_profile(memory: int) -> NodePerformanceProfile:
    return NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=MemoryUsage.from_bytes(
            ram_total=1000,
            ram_available=memory,
            swap_total=1000,
            swap_available=1000,
        ),
        network_interfaces=[],
        system=SystemPerformanceProfile(),
    )
@pytest.fixture
def create_node():
    def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
        if node_id is None:
            node_id = NodeId()
        return NodeInfo(
            node_id=node_id,
            node_profile=NodePerformanceProfile(
                model_id="test",
                chip_id="test",
                friendly_name="test",
                memory=MemoryPerformanceProfile.from_bytes(
                    ram_total=1000,
                    ram_available=memory,
                    swap_total=1000,
                    swap_available=1000,
                ),
                network_interfaces=[],
                system=SystemPerformanceProfile(),
            ),
        )

    return _create_node


# TODO: this is a hack to get the port for the send_back_multiaddr
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
    return SocketConnection(
        sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
    )
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
    port_counter = 1235
    ip_counter = 1

    def _create_connection(
        source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
    ) -> Connection:
        nonlocal port_counter
        nonlocal ip_counter
        # assign unique ips
        ip_counter += 1
        if send_back_port is None:
            send_back_port = port_counter
            port_counter += 1
        return Connection(
            local_node_id=source_node_id,
            send_back_node_id=sink_node_id,
            send_back_multiaddr=Multiaddr(
                address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
            ),
            connection_profile=ConnectionProfile(
                throughput=1000, latency=1000, jitter=1000
            ),
        )


def create_rdma_connection(iface: int) -> RDMAConnection:
    return RDMAConnection(
        source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
    )
    return _create_connection
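The rewritten `create_node` and `create_connection` fixtures compose directly in tests. A minimal sketch of a test building a two-node bidirectional ring with them (the fixture names and the `Topology` API are taken from this diff; the final assertion is illustrative):

```python
# Minimal sketch: a test composing the fixtures above into a two-node
# bidirectional ring. add_node/add_connection/list_nodes are the same
# Topology calls used by the other tests in this diff.
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId

def test_two_node_ring(create_node, create_connection):
    topology = Topology()
    a, b = NodeId(), NodeId()
    topology.add_node(create_node(1000, a))
    topology.add_node(create_node(1000, b))
    topology.add_connection(create_connection(a, b))
    topology.add_connection(create_connection(b, a))
    assert len(list(topology.list_nodes())) == 2
```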
@@ -19,13 +19,15 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    InstanceCreated,
    NodeGatheredInfo,
    NodePerformanceMeasured,
    TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
    MemoryUsage,
    MemoryPerformanceProfile,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -81,14 +83,21 @@ async def test_master():
                origin=sender_node_id,
                session=session_id,
                event=(
                    NodeGatheredInfo(
                    NodePerformanceMeasured(
                        when=str(datetime.now(tz=timezone.utc)),
                        node_id=node_id,
                        info=MemoryUsage(
                            ram_total=Memory.from_bytes(678948 * 1024),
                            ram_available=Memory.from_bytes(678948 * 1024),
                            swap_total=Memory.from_bytes(0),
                            swap_available=Memory.from_bytes(0),
                        node_profile=NodePerformanceProfile(
                            model_id="maccy",
                            chip_id="arm",
                            friendly_name="test",
                            memory=MemoryPerformanceProfile(
                                ram_total=Memory.from_bytes(678948 * 1024),
                                ram_available=Memory.from_bytes(678948 * 1024),
                                swap_total=Memory.from_bytes(0),
                                swap_available=Memory.from_bytes(0),
                            ),
                            network_interfaces=[],
                            system=SystemPerformanceProfile(),
                        ),
                    )
                ),
@@ -114,6 +123,8 @@ async def test_master():
                pretty_name="Llama 3.2 1B",
                n_layers=16,
                storage_size=Memory.from_bytes(678948),
                hidden_size=7168,
                supports_tensor=True,
            ),
            sharding=Sharding.Pipeline,
            instance_meta=InstanceMeta.MlxRing,
@@ -152,34 +163,40 @@ async def test_master():
    assert events[0].idx == 0
    assert events[1].idx == 1
    assert events[2].idx == 2
    assert isinstance(events[0].event, NodeGatheredInfo)
    assert isinstance(events[0].event, NodePerformanceMeasured)
    assert isinstance(events[1].event, InstanceCreated)
    runner_id = list(
        events[1].event.instance.shard_assignments.runner_to_shard.keys()
    )[0]
    assert events[1].event.instance == MlxRingInstance(
        instance_id=events[1].event.instance.instance_id,
        shard_assignments=ShardAssignments(
            model_id=ModelId("llama-3.2-1b"),
            runner_to_shard={
                (runner_id): PipelineShardMetadata(
                    start_layer=0,
                    end_layer=16,
    created_instance = events[1].event.instance
    assert isinstance(created_instance, MlxRingInstance)
    runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
    # Validate the shard assignments
    expected_shard_assignments = ShardAssignments(
        model_id=ModelId("llama-3.2-1b"),
        runner_to_shard={
            (runner_id): PipelineShardMetadata(
                start_layer=0,
                end_layer=16,
                n_layers=16,
                model_meta=ModelMetadata(
                    model_id=ModelId("llama-3.2-1b"),
                    pretty_name="Llama 3.2 1B",
                    n_layers=16,
                model_meta=ModelMetadata(
                    model_id=ModelId("llama-3.2-1b"),
                    pretty_name="Llama 3.2 1B",
                    n_layers=16,
                    storage_size=Memory.from_bytes(678948),
                ),
                device_rank=0,
                world_size=1,
            )
        },
        node_to_runner={node_id: runner_id},
        ),
        hosts=[],
                    storage_size=Memory.from_bytes(678948),
                    hidden_size=7168,
                    supports_tensor=True,
                ),
                device_rank=0,
                world_size=1,
            )
        },
        node_to_runner={node_id: runner_id},
    )
    assert created_instance.shard_assignments == expected_shard_assignments
    # For single-node, hosts_by_node should have one entry with self-binding
    assert len(created_instance.hosts_by_node) == 1
    assert node_id in created_instance.hosts_by_node
    assert len(created_instance.hosts_by_node[node_id]) == 1
    assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
    assert created_instance.ephemeral_port > 0
    assert isinstance(events[2].event, TaskCreated)
    assert events[2].event.task.task_status == TaskStatus.Pending
    assert isinstance(events[2].event.task, ChatCompletionTask)
@@ -1,3 +1,5 @@
from typing import Callable

import pytest
from loguru import logger

@@ -5,20 +7,14 @@ from exo.master.placement import (
    get_transition_events,
    place_instance,
)
from exo.master.tests.conftest import (
    create_connection,
    create_node_profile,
    create_rdma_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
    Instance,
    InstanceId,
@@ -30,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding


@pytest.fixture
def topology() -> Topology:
    return Topology()


@pytest.fixture
def instance() -> Instance:
    return MlxRingInstance(
@@ -37,7 +38,8 @@ def instance() -> Instance:
        shard_assignments=ShardAssignments(
            model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
        ),
        hosts=[],
        hosts_by_node={},
        ephemeral_port=50000,
    )


@@ -48,6 +50,8 @@ def model_meta() -> ModelMetadata:
        storage_size=Memory.from_kb(1000),
        pretty_name="Test Model",
        n_layers=10,
        hidden_size=10,
        supports_tensor=True,
    )

@@ -73,33 +77,34 @@ def test_get_instance_placements_create_instance(
    available_memory: tuple[int, int, int],
    total_layers: int,
    expected_layers: tuple[int, int, int],
    topology: Topology,
    model_meta: ModelMetadata,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    model_meta.n_layers = total_layers
    model_meta.storage_size.in_bytes = sum(
        available_memory
    )  # make it exactly fit across all nodes
    topology = Topology()

    cic = place_instance_command(model_meta)
    node_id_a = NodeId()
    node_id_b = NodeId()
    node_id_c = NodeId()
    profiles = {
        node_id_a: create_node_profile(available_memory[0]),
        node_id_b: create_node_profile(available_memory[1]),
        node_id_c: create_node_profile(available_memory[2]),
    }
    topology.add_node(node_id_a)
    topology.add_node(node_id_b)
    topology.add_node(node_id_c)
    topology.add_connection(node_id_a, node_id_b, create_connection(1))
    topology.add_connection(node_id_b, node_id_c, create_connection(2))
    topology.add_connection(node_id_c, node_id_a, create_connection(3))
    topology.add_node(create_node(available_memory[0], node_id_a))
    topology.add_node(create_node(available_memory[1], node_id_b))
    topology.add_node(create_node(available_memory[2], node_id_c))
    # Add bidirectional connections for ring topology
    topology.add_connection(create_connection(node_id_a, node_id_b))
    topology.add_connection(create_connection(node_id_b, node_id_a))
    topology.add_connection(create_connection(node_id_b, node_id_c))
    topology.add_connection(create_connection(node_id_c, node_id_b))
    topology.add_connection(create_connection(node_id_c, node_id_a))
    topology.add_connection(create_connection(node_id_a, node_id_c))

    # act
    placements = place_instance(cic, topology, {}, profiles)
    placements = place_instance(cic, topology, {})

    # assert
    assert len(placements) == 1
@@ -125,20 +130,23 @@ def test_get_instance_placements_create_instance(
    assert shards_sorted[-1].end_layer == total_layers


def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
    create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
    topology = Topology()
    node_id = NodeId()
    topology.add_node(node_id)
    profiles = {node_id: create_node_profile(1000 * 1024)}
    topology.add_node(create_node(1000 * 1024, node_id))
    cic = place_instance_command(
        ModelMetadata(
            model_id=ModelId("test-model"),
            storage_size=Memory.from_kb(1000),
            pretty_name="Test Model",
            n_layers=10,
            hidden_size=1000,
            supports_tensor=True,
        ),
    )
    placements = place_instance(cic, topology, {}, profiles)
    placements = place_instance(cic, topology, {})

    assert len(placements) == 1
    instance_id = list(placements.keys())[0]
@@ -149,20 +157,23 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
    assert len(instance.shard_assignments.runner_to_shard) == 1


def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
    create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
    topology = Topology()
    node_id = NodeId()
    topology.add_node(node_id)
    profiles = {node_id: create_node_profile(1001 * 1024)}
    topology.add_node(create_node(1001 * 1024, node_id))
    cic = place_instance_command(
        ModelMetadata(
            model_id=ModelId("test-model"),
            storage_size=Memory.from_kb(1000),
            pretty_name="Test Model",
            n_layers=10,
            hidden_size=1000,
            supports_tensor=True,
        ),
    )
    placements = place_instance(cic, topology, {}, profiles)
    placements = place_instance(cic, topology, {})

    assert len(placements) == 1
    instance_id = list(placements.keys())[0]
@@ -173,22 +184,25 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
    assert len(instance.shard_assignments.runner_to_shard) == 1


def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
    create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
    topology = Topology()
    node_id = NodeId()
    topology.add_node(node_id)
    profiles = {node_id: create_node_profile(1000 * 1024)}
    topology.add_node(create_node(1000 * 1024, node_id))
    cic = place_instance_command(
        model_meta=ModelMetadata(
            model_id=ModelId("test-model"),
            storage_size=Memory.from_kb(1001),
            pretty_name="Test Model",
            n_layers=10,
            hidden_size=1000,
            supports_tensor=True,
        ),
    )

    with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
        place_instance(cic, topology, {}, profiles)
        place_instance(cic, topology, {})


def test_get_transition_events_no_change(instance: Instance):
@@ -233,103 +247,179 @@ def test_get_transition_events_delete_instance(instance: Instance):
    assert events[0].instance_id == instance_id


def test_placement_prioritizes_leaf_cycle_with_less_memory(
def test_placement_selects_cycle_with_most_memory(
    topology: Topology,
    model_meta: ModelMetadata,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    topology = Topology()
    # Arrange two 3-node cycles with different total memory.
    # With bidirectional connections for ring topology, both cycles have non-leaf nodes.
    # The algorithm should select the cycle with the most available memory.

    model_meta.storage_size = Memory.from_bytes(1000)
    # Model requires more than any single node but fits within a 3-node cycle
    model_meta.storage_size.in_bytes = 1500
    model_meta.n_layers = 12

    # Create node ids
    node_id_a = NodeId()
    node_id_b = NodeId()
    node_id_c = NodeId()
    node_id_d = NodeId()
    node_id_e = NodeId()
    node_id_f = NodeId()

    profiles = {
        node_id_a: create_node_profile(500),
        node_id_b: create_node_profile(600),
        node_id_c: create_node_profile(600),
        node_id_d: create_node_profile(500),
    }
    # A-B-C cycle total memory = 1600 (< D-E-F total)
    topology.add_node(create_node(400, node_id_a))
    topology.add_node(create_node(400, node_id_b))
    topology.add_node(create_node(800, node_id_c))

    topology.add_node(node_id_a)
    topology.add_node(node_id_b)
    topology.add_node(node_id_c)
    topology.add_node(node_id_d)
    # D-E-F cycle total memory = 1800 (> A-B-C total)
    topology.add_node(create_node(600, node_id_d))
    topology.add_node(create_node(600, node_id_e))
    topology.add_node(create_node(600, node_id_f))

    # Daisy chain topology
    topology.add_connection(node_id_a, node_id_b, create_connection(1))
    topology.add_connection(node_id_b, node_id_a, create_connection(1))
    topology.add_connection(node_id_b, node_id_c, create_connection(1))
    topology.add_connection(node_id_c, node_id_b, create_connection(1))
    topology.add_connection(node_id_c, node_id_d, create_connection(1))
    topology.add_connection(node_id_d, node_id_c, create_connection(1))
    # Build bidirectional cycles for ring topology
    topology.add_connection(create_connection(node_id_a, node_id_b))
    topology.add_connection(create_connection(node_id_b, node_id_a))
    topology.add_connection(create_connection(node_id_b, node_id_c))
    topology.add_connection(create_connection(node_id_c, node_id_b))
    topology.add_connection(create_connection(node_id_c, node_id_a))
    topology.add_connection(create_connection(node_id_a, node_id_c))

    logger.info(list(topology.list_connections()))
    topology.add_connection(create_connection(node_id_d, node_id_e))
    topology.add_connection(create_connection(node_id_e, node_id_d))
    topology.add_connection(create_connection(node_id_e, node_id_f))
    topology.add_connection(create_connection(node_id_f, node_id_e))
    topology.add_connection(create_connection(node_id_f, node_id_d))
    topology.add_connection(create_connection(node_id_d, node_id_f))

    cic = place_instance_command(
        model_meta=model_meta,
    )

    # act
    placements = place_instance(cic, topology, {}, profiles)
    # Act
    placements = place_instance(cic, topology, {})

    # assert
    # Assert: D-E-F cycle should be selected as it has more total memory
    assert len(placements) == 1
    instance = list(placements.values())[0]
    instance_id = list(placements.keys())[0]
    instance = placements[instance_id]

    assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
    assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
        (node_id_c, node_id_d)
    )
    less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
    more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}

    assert more_memory_cycle_nodes.issubset(assigned_nodes)
    assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
    topology: Topology,
    model_meta: ModelMetadata,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    topology = Topology()
    model_meta.n_layers = 12
    model_meta.storage_size.in_bytes = 1500

    node_a = NodeId()
    node_b = NodeId()
    node_c = NodeId()
    node_id_a = NodeId()
    node_id_b = NodeId()
    node_id_c = NodeId()

    profiles = {
        node_a: create_node_profile(500),
        node_b: create_node_profile(500),
        node_c: create_node_profile(500),
    }
    node_a = create_node(500, node_id_a)
    node_b = create_node(500, node_id_b)
    node_c = create_node(500, node_id_c)

    ethernet_interface = NetworkInterfaceInfo(
        name="en0",
        ip_address="192.168.1.100",
    )
    ethernet_conn = SocketConnection(
        sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
    )

    profiles[node_a].network_interfaces = [ethernet_interface]
    profiles[node_b].network_interfaces = [ethernet_interface]
    profiles[node_c].network_interfaces = [ethernet_interface]
    assert node_a.node_profile is not None
    assert node_b.node_profile is not None
    assert node_c.node_profile is not None

    conn_a_b = create_connection(node_id_a, node_id_b)
    conn_b_c = create_connection(node_id_b, node_id_c)
    conn_c_a = create_connection(node_id_c, node_id_a)

    conn_b_a = create_connection(node_id_b, node_id_a)
    conn_c_b = create_connection(node_id_c, node_id_b)
    conn_a_c = create_connection(node_id_a, node_id_c)

    assert conn_a_b.send_back_multiaddr is not None
    assert conn_b_c.send_back_multiaddr is not None
    assert conn_c_a.send_back_multiaddr is not None

    assert conn_b_a.send_back_multiaddr is not None
    assert conn_c_b.send_back_multiaddr is not None
    assert conn_a_c.send_back_multiaddr is not None

    node_a.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_a.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_c_a.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_b_a.send_back_multiaddr.ip_address,
            ),
            ethernet_interface,
        ],
        system=node_a.node_profile.system,
    )
    node_b.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_b.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_c_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_a_b.send_back_multiaddr.ip_address,
            ),
            ethernet_interface,
        ],
        system=node_b.node_profile.system,
    )
    node_c.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_c.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_a_c.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_b_c.send_back_multiaddr.ip_address,
            ),
            ethernet_interface,
        ],
        system=node_c.node_profile.system,
    )

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)
    topology.add_connection(node_a, node_b, create_rdma_connection(3))
    topology.add_connection(node_b, node_c, create_rdma_connection(4))
    topology.add_connection(node_c, node_a, create_rdma_connection(5))
    topology.add_connection(node_b, node_a, create_rdma_connection(3))
    topology.add_connection(node_c, node_b, create_rdma_connection(4))
    topology.add_connection(node_a, node_c, create_rdma_connection(5))

    topology.add_connection(node_a, node_b, ethernet_conn)
    topology.add_connection(node_b, node_c, ethernet_conn)
    topology.add_connection(node_c, node_a, ethernet_conn)
    topology.add_connection(node_a, node_c, ethernet_conn)
    topology.add_connection(node_b, node_a, ethernet_conn)
    topology.add_connection(node_c, node_b, ethernet_conn)
    topology.add_connection(conn_a_b)
    topology.add_connection(conn_b_c)
    topology.add_connection(conn_c_a)
    topology.add_connection(conn_b_a)
    topology.add_connection(conn_c_b)
    topology.add_connection(conn_a_c)

cic = PlaceInstance(
|
||||
sharding=Sharding.Tensor,
|
||||
@@ -339,7 +429,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
min_nodes=1,
|
||||
)
|
||||
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -347,10 +437,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
|
||||
assert instance.jaccl_devices is not None
|
||||
assert instance.ibv_devices is not None
|
||||
assert instance.jaccl_coordinators is not None
|
||||
|
||||
matrix = instance.jaccl_devices
|
||||
matrix = instance.ibv_devices
|
||||
assert len(matrix) == 3
|
||||
|
||||
for i in range(3):
|
||||
@@ -359,15 +449,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
|
||||
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
|
||||
|
||||
idx_a = node_to_idx[node_a]
|
||||
idx_b = node_to_idx[node_b]
|
||||
idx_c = node_to_idx[node_c]
|
||||
idx_a = node_to_idx[node_id_a]
|
||||
idx_b = node_to_idx[node_id_b]
|
||||
idx_c = node_to_idx[node_id_c]
|
||||
|
||||
logger.info(matrix)
|
||||
|
||||
assert matrix[idx_a][idx_b] == "rdma_en3"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en4"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en5"
|
||||
assert matrix[idx_a][idx_b] == "rdma_en4"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en3"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en3"
|
||||
|
||||
# Verify coordinators are set for all nodes
|
||||
assert len(instance.jaccl_coordinators) == 3
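
For orientation: the connectivity matrix asserted above is an N×N table in which entry (i, j) names the RDMA interface rank i should use to reach rank j. A minimal sketch of how such a matrix can be assembled, assuming a hypothetical `lookup_iface` helper (an illustration under stated assumptions, not exo's actual implementation):

```python
from typing import Callable, Optional

def build_connectivity_matrix(
    ranks: list[str],
    lookup_iface: Callable[[str, str], Optional[str]],
) -> list[list[str]]:
    # matrix[i][j] names the interface rank i uses to send to rank j;
    # the diagonal is a placeholder, since a node never sends to itself.
    return [
        [
            "self" if src == dst else (lookup_iface(src, dst) or "none")
            for dst in ranks
        ]
        for src in ranks
    ]
```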

@@ -1,48 +1,56 @@
+from typing import Callable
+
 import pytest

 from exo.master.placement_utils import (
-    NodeWithProfile,
     filter_cycles_by_memory,
     get_hosts_from_subgraph,
     get_mlx_jaccl_coordinators,
     get_shard_assignments,
     get_smallest_cycles,
 )
-from exo.master.tests.conftest import create_connection, create_node_profile
 from exo.shared.topology import Topology
 from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
+from exo.shared.types.topology import Connection, NodeInfo
 from exo.shared.types.worker.shards import Sharding


-def test_filter_cycles_by_memory():
+@pytest.fixture
+def topology() -> Topology:
+    topology = Topology()
+    return topology
+
+
+def test_filter_cycles_by_memory(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
+):
     # arrange
     node1_id = NodeId()
     node2_id = NodeId()
-    topology = Topology()

-    node1 = create_node_profile(1000 * 1024)
-    node2 = create_node_profile(1000 * 1024)
-    node_profiles = {node1_id: node1, node2_id: node2}
+    node1 = create_node(1000 * 1024, node1_id)
+    node2 = create_node(1000 * 1024, node2_id)

-    topology.add_node(node1_id)
-    topology.add_node(node2_id)
+    topology.add_node(node1)
+    topology.add_node(node2)

-    connection1 = create_connection(1)
-    connection2 = create_connection(2)
+    connection1 = create_connection(node1_id, node2_id)
+    connection2 = create_connection(node2_id, node1_id)

-    topology.add_connection(node1_id, node2_id, connection1)
-    topology.add_connection(node2_id, node1_id, connection2)
+    topology.add_connection(connection1)
+    topology.add_connection(connection2)

     cycles = topology.get_cycles()
     assert len(cycles) == 1
     assert len(cycles[0]) == 2

     # act
-    filtered_cycles = filter_cycles_by_memory(
-        cycles, node_profiles, Memory.from_bytes(1)
-    )
+    filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))

     # assert
     assert len(filtered_cycles) == 1
@@ -50,65 +58,64 @@ def test_filter_cycles_by_memory():
     assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}


-def test_filter_cycles_by_insufficient_memory():
+def test_filter_cycles_by_insufficient_memory(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
+):
     # arrange
     node1_id = NodeId()
     node2_id = NodeId()
-    topology = Topology()

-    node1 = create_node_profile(1000 * 1024)
-    node2 = create_node_profile(1000 * 1024)
-    node_profiles = {node1_id: node1, node2_id: node2}
+    node1 = create_node(1000 * 1024, node1_id)
+    node2 = create_node(1000 * 1024, node2_id)

-    topology.add_node(node1_id)
-    topology.add_node(node2_id)
+    topology.add_node(node1)
+    topology.add_node(node2)

-    connection1 = create_connection(1)
-    connection2 = create_connection(2)
+    connection1 = create_connection(node1_id, node2_id)
+    connection2 = create_connection(node2_id, node1_id)

-    topology.add_connection(node1_id, node2_id, connection1)
-    topology.add_connection(node2_id, node1_id, connection2)
+    topology.add_connection(connection1)
+    topology.add_connection(connection2)

     # act
     filtered_cycles = filter_cycles_by_memory(
-        topology.get_cycles(), node_profiles, Memory.from_kb(2001)
+        topology.get_cycles(), Memory.from_kb(2001)
     )

     # assert
     assert len(filtered_cycles) == 0


-def test_filter_multiple_cycles_by_memory():
+def test_filter_multiple_cycles_by_memory(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
+):
     # arrange
     node_a_id = NodeId()
     node_b_id = NodeId()
     node_c_id = NodeId()
-    topology = Topology()

-    node_a = create_node_profile(500 * 1024)
-    node_b = create_node_profile(500 * 1024)
-    node_c = create_node_profile(1000 * 1024)
-    node_profiles = {
-        node_a_id: node_a,
-        node_b_id: node_b,
-        node_c_id: node_c,
-    }
+    node_a = create_node(500 * 1024, node_a_id)
+    node_b = create_node(500 * 1024, node_b_id)
+    node_c = create_node(1000 * 1024, node_c_id)

-    topology.add_node(node_a_id)
-    topology.add_node(node_b_id)
-    topology.add_node(node_c_id)
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)

-    topology.add_connection(node_a_id, node_b_id, create_connection(1))
-    topology.add_connection(node_b_id, node_a_id, create_connection(2))
-    topology.add_connection(node_a_id, node_c_id, create_connection(3))
-    topology.add_connection(node_c_id, node_b_id, create_connection(4))
+    topology.add_connection(create_connection(node_a_id, node_b_id))
+    topology.add_connection(create_connection(node_b_id, node_a_id))
+
+    topology.add_connection(create_connection(node_a_id, node_c_id))
+    topology.add_connection(create_connection(node_c_id, node_b_id))

     cycles = topology.get_cycles()

     # act
-    filtered_cycles = filter_cycles_by_memory(
-        cycles, node_profiles, Memory.from_kb(1500)
-    )
+    filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))

     # assert
     assert len(filtered_cycles) == 1
@@ -120,38 +127,31 @@ def test_filter_multiple_cycles_by_memory():
     }


-def test_get_smallest_cycles():
+def test_get_smallest_cycles(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
+):
     # arrange
     node_a_id = NodeId()
     node_b_id = NodeId()
     node_c_id = NodeId()
-    topology = Topology()

-    node_a = create_node_profile(500 * 1024)
-    node_b = create_node_profile(500 * 1024)
-    node_c = create_node_profile(1000 * 1024)
-    node_profiles = {
-        node_a_id: node_a,
-        node_b_id: node_b,
-        node_c_id: node_c,
-    }
+    node_a = create_node(500 * 1024, node_a_id)
+    node_b = create_node(500 * 1024, node_b_id)
+    node_c = create_node(1000 * 1024, node_c_id)

-    topology.add_node(node_a_id)
-    topology.add_node(node_b_id)
-    topology.add_node(node_c_id)
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)

-    topology.add_connection(node_a_id, node_b_id, create_connection(1))
-    topology.add_connection(node_b_id, node_a_id, create_connection(2))
-    topology.add_connection(node_a_id, node_c_id, create_connection(3))
-    topology.add_connection(node_c_id, node_b_id, create_connection(4))
-
-    cycles = [
-        [NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
-        for cycle in topology.get_cycles()
-    ]
+    topology.add_connection(create_connection(node_a_id, node_b_id))
+    topology.add_connection(create_connection(node_b_id, node_c_id))
+    topology.add_connection(create_connection(node_c_id, node_a_id))
+    topology.add_connection(create_connection(node_b_id, node_a_id))

     # act
-    smallest_cycles = get_smallest_cycles(cycles)
+    smallest_cycles = get_smallest_cycles(topology.get_cycles())

     # assert
     assert len(smallest_cycles) == 1
@@ -168,6 +168,9 @@ def test_get_smallest_cycles():
     ],
 )
 def test_get_shard_assignments(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
     available_memory: tuple[int, int, int],
     total_layers: int,
     expected_layers: tuple[int, int, int],
@@ -176,37 +179,29 @@ def test_get_shard_assignments(
     node_a_id = NodeId()
     node_b_id = NodeId()
     node_c_id = NodeId()
-    topology = Topology()

-    node_a = create_node_profile(available_memory[0] * 1024)
-    node_b = create_node_profile(available_memory[1] * 1024)
-    node_c = create_node_profile(available_memory[2] * 1024)
-    node_profiles = {
-        node_a_id: node_a,
-        node_b_id: node_b,
-        node_c_id: node_c,
-    }
+    node_a = create_node(available_memory[0] * 1024, node_a_id)
+    node_b = create_node(available_memory[1] * 1024, node_b_id)
+    node_c = create_node(available_memory[2] * 1024, node_c_id)

-    topology.add_node(node_a_id)
-    topology.add_node(node_b_id)
-    topology.add_node(node_c_id)
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)

-    topology.add_connection(node_a_id, node_b_id, create_connection(1))
-    topology.add_connection(node_b_id, node_c_id, create_connection(2))
-    topology.add_connection(node_c_id, node_a_id, create_connection(3))
-    topology.add_connection(node_b_id, node_a_id, create_connection(4))
+    topology.add_connection(create_connection(node_a_id, node_b_id))
+    topology.add_connection(create_connection(node_b_id, node_c_id))
+    topology.add_connection(create_connection(node_c_id, node_a_id))
+    topology.add_connection(create_connection(node_b_id, node_a_id))

     model_meta = ModelMetadata(
         model_id=ModelId("test-model"),
         pretty_name="Test Model",
         n_layers=total_layers,
         storage_size=Memory.from_kb(1000),
         hidden_size=1000,
         supports_tensor=True,
     )

-    cycles = [
-        [NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
-        for cycle in topology.get_cycles()
-    ]
+    cycles = topology.get_cycles()
     selected_cycle = cycles[0]

     # act
@@ -235,21 +230,28 @@ def test_get_shard_assignments(
     )


-def test_get_hosts_from_subgraph():
+def test_get_hosts_from_subgraph(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
+):
     # arrange
     node_a_id = NodeId()
     node_b_id = NodeId()
     node_c_id = NodeId()
-    topology = Topology()

-    topology.add_node(node_a_id)
-    topology.add_node(node_b_id)
-    topology.add_node(node_c_id)
+    node_a = create_node(500, node_a_id)
+    node_b = create_node(500, node_b_id)
+    node_c = create_node(1000, node_c_id)

-    topology.add_connection(node_a_id, node_b_id, create_connection(1))
-    topology.add_connection(node_b_id, node_a_id, create_connection(2))
-    topology.add_connection(node_a_id, node_c_id, create_connection(3))
-    topology.add_connection(node_c_id, node_b_id, create_connection(4))
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)
+
+    topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
+    topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
+    topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
+    topology.add_connection(create_connection(node_b_id, node_a_id, 5004))

     # act
     hosts = get_hosts_from_subgraph(topology)
@@ -257,47 +259,108 @@ def test_get_hosts_from_subgraph():
     # assert
     assert len(hosts) == 3
     expected_hosts = [
-        Host(ip=("169.254.0.2"), port=1234),
-        Host(ip=("169.254.0.3"), port=1234),
-        Host(ip=("169.254.0.4"), port=1234),
+        Host(ip=("169.254.0.2"), port=5001),
+        Host(ip=("169.254.0.3"), port=5002),
+        Host(ip=("169.254.0.4"), port=5003),
     ]
     for expected_host in expected_hosts:
         assert expected_host in hosts


-def test_get_mlx_jaccl_coordinators():
+def test_get_mlx_jaccl_coordinators(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
+):
     # arrange
     node_a_id = NodeId()
     node_b_id = NodeId()
     node_c_id = NodeId()
-    topology = Topology()

-    topology.add_node(node_a_id)
-    topology.add_node(node_b_id)
-    topology.add_node(node_c_id)
+    node_a = create_node(500 * 1024, node_a_id)
+    node_b = create_node(500 * 1024, node_b_id)
+    node_c = create_node(1000 * 1024, node_c_id)

-    topology.add_connection(node_a_id, node_b_id, create_connection(1))
-    topology.add_connection(node_b_id, node_a_id, create_connection(2))
-    topology.add_connection(node_a_id, node_c_id, create_connection(3))
-    topology.add_connection(node_c_id, node_b_id, create_connection(4))
+    conn_a_b = create_connection(node_a_id, node_b_id, 5001)
+    conn_b_a = create_connection(node_b_id, node_a_id, 5002)
+    conn_b_c = create_connection(node_b_id, node_c_id, 5003)
+    conn_c_b = create_connection(node_c_id, node_b_id, 5004)
+    conn_c_a = create_connection(node_c_id, node_a_id, 5005)
+    conn_a_c = create_connection(node_a_id, node_c_id, 5006)

-    conn_a_b = create_connection(1)
-    conn_b_a = create_connection(2)
-    conn_b_c = create_connection(3)
-    conn_c_b = create_connection(4)
-    conn_c_a = create_connection(5)
-    conn_a_c = create_connection(6)
+    # Update node profiles with network interfaces before adding to topology
+    assert node_a.node_profile is not None
+    assert node_b.node_profile is not None
+    assert node_c.node_profile is not None

-    topology.add_connection(node_a_id, node_b_id, conn_a_b)
-    topology.add_connection(node_b_id, node_a_id, conn_b_a)
-    topology.add_connection(node_b_id, node_c_id, conn_b_c)
-    topology.add_connection(node_c_id, node_b_id, conn_c_b)
-    topology.add_connection(node_c_id, node_a_id, conn_c_a)
-    topology.add_connection(node_a_id, node_c_id, conn_a_c)
+    node_a.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_a.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_a_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_a_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_a.node_profile.system,
+    )
+    node_b.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_b.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_b_a.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_b.node_profile.system,
+    )
+    node_c.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_c.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_c_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_c_a.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_c.node_profile.system,
+    )
+
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)
+
+    topology.add_connection(conn_a_b)
+    topology.add_connection(conn_b_a)
+    topology.add_connection(conn_b_c)
+    topology.add_connection(conn_c_b)
+    topology.add_connection(conn_c_a)
+    topology.add_connection(conn_a_c)
+
+    cycle = [node_a, node_b, node_c]

     # act
     coordinators = get_mlx_jaccl_coordinators(
-        node_a_id, coordinator_port=5000, cycle_digraph=topology
+        cycle, coordinator_port=5000, cycle_digraph=topology
     )

     # assert
@@ -326,11 +389,11 @@ def test_get_mlx_jaccl_coordinators():

     # Non-rank-0 nodes should use the specific IP from their connection to rank 0
     # node_b uses the IP from conn_b_a (node_b -> node_a)
-    assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
-        "node_b should use the IP from conn_b_a"
-    )
+    assert coordinators[node_b_id] == (
+        f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_b should use the IP from conn_b_a"

     # node_c uses the IP from conn_c_a (node_c -> node_a)
-    assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
-        "node_c should use the IP from conn_c_a"
-    )
+    assert coordinators[node_c_id] == (
+        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_c should use the IP from conn_c_a"
@@ -1,14 +1,13 @@
 import pytest

 from exo.shared.topology import Topology
-from exo.shared.types.common import NodeId
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.profiling import (
-    MemoryUsage,
+    MemoryPerformanceProfile,
     NodePerformanceProfile,
     SystemPerformanceProfile,
 )
-from exo.shared.types.topology import SocketConnection
+from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo


 @pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:


 @pytest.fixture
-def connection() -> SocketConnection:
-    return SocketConnection(
-        sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
+def connection() -> Connection:
+    return Connection(
+        local_node_id=NodeId(),
+        send_back_node_id=NodeId(),
+        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
+        connection_profile=ConnectionProfile(
+            throughput=1000, latency=1000, jitter=1000
+        ),
     )


 @pytest.fixture
 def node_profile() -> NodePerformanceProfile:
-    memory_profile = MemoryUsage.from_bytes(
+    memory_profile = MemoryPerformanceProfile.from_bytes(
         ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
     )
     system_profile = SystemPerformanceProfile()
@@ -39,85 +43,162 @@ def node_profile() -> NodePerformanceProfile:
     )


-def test_add_node(topology: Topology):
+@pytest.fixture
+def connection_profile() -> ConnectionProfile:
+    return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
+
+
+def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
     # arrange
     node_id = NodeId()

     # act
-    topology.add_node(node_id)
+    topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))

     # assert
-    assert topology.node_is_leaf(node_id)
+    data = topology.get_node_profile(node_id)
+    assert data == node_profile


-def test_add_connection(topology: Topology, connection: SocketConnection):
+def test_add_connection(
+    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
+):
     # arrange
-    node_a = NodeId()
-    node_b = NodeId()
-
-    topology.add_node(node_a)
-    topology.add_node(node_b)
-    topology.add_connection(node_a, node_b, connection)
+    topology.add_node(
+        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
+    )
+    topology.add_node(
+        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
+    )
+    topology.add_connection(connection)

     # act
-    data = list(conn for _, _, conn in topology.list_connections())
+    data = topology.get_connection_profile(connection)

     # assert
-    assert data == [connection]
+    assert data == connection.connection_profile

-    assert topology.node_is_leaf(node_a)
-    assert topology.node_is_leaf(node_b)

+def test_update_node_profile(
+    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
+):
+    # arrange
+    topology.add_node(
+        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
+    )
+    topology.add_node(
+        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
+    )
+    topology.add_connection(connection)
+
+    new_node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=MemoryPerformanceProfile.from_bytes(
+            ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
+        ),
+        network_interfaces=[],
+        system=SystemPerformanceProfile(),
+    )
+
+    # act
+    topology.update_node_profile(
+        connection.local_node_id, node_profile=new_node_profile
+    )
+
+    # assert
+    data = topology.get_node_profile(connection.local_node_id)
+    assert data == new_node_profile
+
+
+def test_update_connection_profile(
+    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
+):
+    # arrange
+    topology.add_node(
+        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
+    )
+    topology.add_node(
+        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
+    )
+    topology.add_connection(connection)
+
+    new_connection_profile = ConnectionProfile(
+        throughput=2000, latency=2000, jitter=2000
+    )
+    connection = Connection(
+        local_node_id=connection.local_node_id,
+        send_back_node_id=connection.send_back_node_id,
+        send_back_multiaddr=connection.send_back_multiaddr,
+        connection_profile=new_connection_profile,
+    )
+
+    # act
+    topology.update_connection_profile(connection)
+
+    # assert
+    data = topology.get_connection_profile(connection)
+    assert data == new_connection_profile
+

 def test_remove_connection_still_connected(
-    topology: Topology, connection: SocketConnection
+    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
 ):
     # arrange
-    node_a = NodeId()
-    node_b = NodeId()
-
-    topology.add_node(node_a)
-    topology.add_node(node_b)
-    topology.add_connection(node_a, node_b, connection)
+    topology.add_node(
+        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
+    )
+    topology.add_node(
+        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
+    )
+    topology.add_connection(connection)

     # act
-    topology.remove_connection(node_a, node_b, connection)
+    topology.remove_connection(connection)

     # assert
-    assert list(topology.get_all_connections_between(node_a, node_b)) == []
+    assert topology.get_connection_profile(connection) is None


-def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
+def test_remove_node_still_connected(
+    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
+):
     # arrange
-    node_a = NodeId()
-    node_b = NodeId()
-
-    topology.add_node(node_a)
-    topology.add_node(node_b)
-    topology.add_connection(node_a, node_b, connection)
-    assert list(topology.out_edges(node_a)) == [(node_b, connection)]
+    topology.add_node(
+        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
+    )
+    topology.add_node(
+        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
+    )
+    topology.add_connection(connection)

     # act
-    topology.remove_node(node_b)
+    topology.remove_node(connection.local_node_id)

     # assert
-    assert list(topology.out_edges(node_a)) == []
+    assert topology.get_node_profile(connection.local_node_id) is None


-def test_list_nodes(topology: Topology, connection: SocketConnection):
+def test_list_nodes(
+    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
+):
     # arrange
-    node_a = NodeId()
-    node_b = NodeId()
-
-    topology.add_node(node_a)
-    topology.add_node(node_b)
-    topology.add_connection(node_a, node_b, connection)
-    assert list(topology.out_edges(node_a)) == [(node_b, connection)]
+    topology.add_node(
+        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
+    )
+    topology.add_node(
+        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
+    )
+    topology.add_connection(connection)

     # act
     nodes = list(topology.list_nodes())

     # assert
     assert len(nodes) == 2
-    assert all(isinstance(node, NodeId) for node in nodes)
-    assert {node for node in nodes} == {node_a, node_b}
+    assert all(isinstance(node, NodeInfo) for node in nodes)
+    assert {node.node_id for node in nodes} == {
+        connection.local_node_id,
+        connection.send_back_node_id,
+    }
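
Taken together, the rewritten fixtures above show the shape of the new topology API: a node's profile travels inside `NodeInfo`, and a `Connection` is self-describing (it carries both endpoint ids and its own `ConnectionProfile`). A condensed usage sketch, using only calls that appear in these tests:

```python
topology = Topology()

conn = Connection(
    local_node_id=NodeId(),
    send_back_node_id=NodeId(),
    send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
    connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000),
)

# Nodes are registered together with their profiles...
topology.add_node(NodeInfo(node_id=conn.local_node_id, node_profile=node_profile))
topology.add_node(NodeInfo(node_id=conn.send_back_node_id, node_profile=node_profile))
# ...and the connection knows its own endpoints, so no extra ids are passed.
topology.add_connection(conn)

assert topology.get_connection_profile(conn) == conn.connection_profile
```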

@@ -11,8 +11,10 @@ from exo.shared.types.events import (
     IndexedEvent,
     InstanceCreated,
     InstanceDeleted,
+    NodeCreated,
     NodeDownloadProgress,
-    NodeGatheredInfo,
+    NodeMemoryMeasured,
     NodePerformanceMeasured,
     NodeTimedOut,
     RunnerDeleted,
     RunnerStatusUpdated,
@@ -25,23 +27,13 @@ from exo.shared.types.events import (
     TopologyEdgeCreated,
     TopologyEdgeDeleted,
 )
-from exo.shared.types.profiling import NodePerformanceProfile
+from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
 from exo.shared.types.state import State
 from exo.shared.types.tasks import Task, TaskId, TaskStatus
-from exo.shared.types.topology import RDMAConnection
+from exo.shared.types.topology import NodeInfo
 from exo.shared.types.worker.downloads import DownloadProgress
 from exo.shared.types.worker.instances import Instance, InstanceId
 from exo.shared.types.worker.runners import RunnerId, RunnerStatus
-from exo.utils.info_gatherer.info_gatherer import (
-    MacmonMetrics,
-    MacTBConnections,
-    MacTBIdentifiers,
-    MemoryUsage,
-    MiscData,
-    NodeConfig,
-    NodeNetworkInterfaces,
-    StaticNodeInformation,
-)


 def event_apply(event: Event, state: State) -> State:
@@ -55,12 +47,16 @@ def event_apply(event: Event, state: State) -> State:
             return apply_instance_created(event, state)
         case InstanceDeleted():
             return apply_instance_deleted(event, state)
+        case NodeCreated():
+            return apply_topology_node_created(event, state)
         case NodeTimedOut():
             return apply_node_timed_out(event, state)
         case NodePerformanceMeasured():
             return apply_node_performance_measured(event, state)
         case NodeDownloadProgress():
             return apply_node_download_progress(event, state)
-        case NodeGatheredInfo():
-            return apply_node_gathered_info(event, state)
+        case NodeMemoryMeasured():
+            return apply_node_memory_measured(event, state)
         case RunnerDeleted():
             return apply_runner_deleted(event, state)
         case RunnerStatusUpdated():
@@ -192,7 +188,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:


 def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
-    topology = copy.deepcopy(state.topology)
+    topology = copy.copy(state.topology)
     state.topology.remove_node(event.node_id)
     node_profiles = {
         key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +196,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
     last_seen = {
         key: value for key, value in state.last_seen.items() if key != event.node_id
     }
-    downloads = {
-        key: value for key, value in state.downloads.items() if key != event.node_id
-    }
     return state.model_copy(
         update={
-            "downloads": downloads,
            "topology": topology,
             "node_profiles": node_profiles,
             "last_seen": last_seen,
@@ -213,69 +205,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
     )


-def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
-    topology = copy.deepcopy(state.topology)
-    topology.add_node(event.node_id)
-    info = event.info
-    profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
-    # TODO: should be broken up into individual events instead of this monster
-    match info:
-        case MacmonMetrics():
-            profile.system = info.system_profile
-            profile.memory = info.memory
-        case MemoryUsage():
-            profile.memory = info
-        case NodeConfig():
-            pass
-        case MiscData():
-            profile.friendly_name = info.friendly_name
-        case StaticNodeInformation():
-            profile.model_id = info.model
-            profile.chip_id = info.chip
-        # TODO: makes me slightly sad
-        case NodeNetworkInterfaces():
-            profile.network_interfaces = info.ifaces
-        case MacTBIdentifiers():
-            profile.tb_interfaces = info.idents
-        case MacTBConnections():
-            conn_map = {
-                tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
-                for nid in state.node_profiles
-                for tb_ident in state.node_profiles[nid].tb_interfaces
-            }
-            as_rdma_conns = [
-                (
-                    conn_map[tb_conn.sink_uuid][0],
-                    RDMAConnection(
-                        source_rdma_iface=conn_map[tb_conn.source_uuid][1],
-                        sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
-                    ),
-                )
-                for tb_conn in info.conns
-                if tb_conn.source_uuid in conn_map
-                if tb_conn.sink_uuid in conn_map
-            ]
-            topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
-
-    last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
-    new_profiles = {**state.node_profiles, event.node_id: profile}
 def apply_node_performance_measured(
     event: NodePerformanceMeasured, state: State
 ) -> State:
     new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
         **state.node_profiles,
         event.node_id: event.node_profile,
     }
     last_seen: Mapping[NodeId, datetime] = {
         **state.last_seen,
         event.node_id: datetime.fromisoformat(event.when),
     }
-    state = state.model_copy(update={"node_profiles": new_profiles})
+    topology = copy.copy(state.topology)
+    # TODO: NodeCreated
+    if not topology.contains_node(event.node_id):
+        topology.add_node(NodeInfo(node_id=event.node_id))
+    topology.update_node_profile(event.node_id, event.node_profile)
     return state.model_copy(
         update={
             "node_profiles": new_profiles,
-            "last_seen": last_seen,
+            "topology": topology,
             "last_seen": last_seen,
         }
     )


+def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
+    existing = state.node_profiles.get(event.node_id)
+    topology = copy.copy(state.topology)
+
+    if existing is None:
+        created = NodePerformanceProfile(
+            model_id="unknown",
+            chip_id="unknown",
+            friendly_name="Unknown",
+            memory=event.memory,
+            network_interfaces=[],
+            system=SystemPerformanceProfile(
+                # TODO: flops_fp16=0.0,
+                gpu_usage=0.0,
+                temp=0.0,
+                sys_power=0.0,
+                pcpu_usage=0.0,
+                ecpu_usage=0.0,
+                ane_power=0.0,
+            ),
+        )
+        created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
+            **state.node_profiles,
+            event.node_id: created,
+        }
+        last_seen: Mapping[NodeId, datetime] = {
+            **state.last_seen,
+            event.node_id: datetime.fromisoformat(event.when),
+        }
+        if not topology.contains_node(event.node_id):
+            topology.add_node(NodeInfo(node_id=event.node_id))
+        # TODO: NodeCreated
+        topology.update_node_profile(event.node_id, created)
+        return state.model_copy(
+            update={
+                "node_profiles": created_profiles,
+                "topology": topology,
+                "last_seen": last_seen,
+            }
+        )
+
+    updated = existing.model_copy(update={"memory": event.memory})
+    updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
+        **state.node_profiles,
+        event.node_id: updated,
+    }
+    # TODO: NodeCreated
+    if not topology.contains_node(event.node_id):
+        topology.add_node(NodeInfo(node_id=event.node_id))
+    topology.update_node_profile(event.node_id, updated)
+    return state.model_copy(
+        update={"node_profiles": updated_profiles, "topology": topology}
+    )
+
+
+def apply_topology_node_created(event: NodeCreated, state: State) -> State:
+    topology = copy.copy(state.topology)
+    topology.add_node(NodeInfo(node_id=event.node_id))
+    return state.model_copy(update={"topology": topology})


 def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
-    topology = copy.deepcopy(state.topology)
-    topology.add_connection(event.source, event.sink, event.edge)
+    topology = copy.copy(state.topology)
+    topology.add_connection(event.edge)
     return state.model_copy(update={"topology": topology})


 def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
-    topology = copy.deepcopy(state.topology)
-    topology.remove_connection(event.sink, event.source, event.edge)
+    topology = copy.copy(state.topology)
+    if not topology.contains_connection(event.edge):
+        return state
+    topology.remove_connection(event.edge)
+    # TODO: Clean up removing the reverse connection
     return state.model_copy(update={"topology": topology})
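
The recurring change in this file is `copy.deepcopy(state.topology)` becoming `copy.copy(state.topology)`. A shallow copy duplicates only the outer object and shares everything it references, so it is far cheaper per event, but it is only safe if the topology treats its internals as immutable or replaces them on write. A generic illustration of the trade-off (standard-library behaviour, not exo's `Topology` internals):

```python
import copy

state = {"edges": [1, 2, 3]}

shallow = copy.copy(state)   # new dict, same inner list object
deep = copy.deepcopy(state)  # new dict and a new inner list

state["edges"].append(4)
assert shallow["edges"] == [1, 2, 3, 4]  # shallow copy observes the mutation
assert deep["edges"] == [1, 2, 3]        # deep copy is fully isolated
```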

@@ -38,7 +38,6 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"

 # Identity (config)
 EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
-EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"

 # libp2p topics for event forwarding
 LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

@@ -24,8 +24,6 @@ class _InterceptHandler(logging.Handler):
         except ValueError:
             level = record.levelno

-            return
-
         logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())

@@ -51,6 +51,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="DeepSeek V3.1 (4-bit)",
             storage_size=Memory.from_gb(378),
             n_layers=61,
+            hidden_size=7168,
+            supports_tensor=True,
         ),
     ),
     "deepseek-v3.1-8bit": ModelCard(
@@ -64,6 +66,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="DeepSeek V3.1 (8-bit)",
             storage_size=Memory.from_gb(713),
             n_layers=61,
+            hidden_size=7168,
+            supports_tensor=True,
         ),
     ),
     # "deepseek-v3.2": ModelCard(
@@ -135,6 +139,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Kimi K2 Instruct (4-bit)",
             storage_size=Memory.from_gb(578),
             n_layers=61,
+            hidden_size=7168,
+            supports_tensor=True,
         ),
     ),
     "kimi-k2-thinking": ModelCard(
@@ -148,6 +154,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Kimi K2 Thinking (4-bit)",
             storage_size=Memory.from_gb(658),
             n_layers=61,
+            hidden_size=7168,
+            supports_tensor=True,
         ),
     ),
     # llama-3.1
@@ -162,6 +170,38 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.1 8B (4-bit)",
             storage_size=Memory.from_mb(4423),
             n_layers=32,
+            hidden_size=4096,
+            supports_tensor=True,
         ),
     ),
+    "llama-3.1-8b-8bit": ModelCard(
+        short_id="llama-3.1-8b-8bit",
+        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
+        name="Llama 3.1 8B (8-bit)",
+        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
+            pretty_name="Llama 3.1 8B (8-bit)",
+            storage_size=Memory.from_mb(8540),
+            n_layers=32,
+            hidden_size=4096,
+            supports_tensor=True,
+        ),
+    ),
+    "llama-3.1-8b-bf16": ModelCard(
+        short_id="llama-3.1-8b-bf16",
+        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
+        name="Llama 3.1 8B (BF16)",
+        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
+            pretty_name="Llama 3.1 8B (BF16)",
+            storage_size=Memory.from_mb(16100),
+            n_layers=32,
+            hidden_size=4096,
+            supports_tensor=True,
+        ),
+    ),
     "llama-3.1-70b": ModelCard(
@@ -175,6 +215,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.1 70B (4-bit)",
             storage_size=Memory.from_mb(38769),
             n_layers=80,
+            hidden_size=8192,
+            supports_tensor=True,
         ),
     ),
     # llama-3.2
@@ -189,6 +231,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.2 1B (4-bit)",
             storage_size=Memory.from_mb(696),
             n_layers=16,
+            hidden_size=2048,
+            supports_tensor=True,
         ),
     ),
     "llama-3.2-3b": ModelCard(
@@ -202,6 +246,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.2 3B (4-bit)",
             storage_size=Memory.from_mb(1777),
             n_layers=28,
+            hidden_size=3072,
+            supports_tensor=True,
         ),
     ),
     "llama-3.2-3b-8bit": ModelCard(
@@ -215,6 +261,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.2 3B (8-bit)",
             storage_size=Memory.from_mb(3339),
             n_layers=28,
+            hidden_size=3072,
+            supports_tensor=True,
         ),
     ),
     # llama-3.3
@@ -229,6 +277,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.3 70B",
             storage_size=Memory.from_mb(38769),
             n_layers=80,
+            hidden_size=8192,
+            supports_tensor=True,
         ),
     ),
     "llama-3.3-70b-8bit": ModelCard(
@@ -242,6 +292,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.3 70B (8-bit)",
             storage_size=Memory.from_mb(73242),
             n_layers=80,
+            hidden_size=8192,
+            supports_tensor=True,
         ),
     ),
     "llama-3.3-70b-fp16": ModelCard(
@@ -255,20 +307,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Llama 3.3 70B (FP16)",
             storage_size=Memory.from_mb(137695),
             n_layers=80,
         ),
     ),
-    # phi-3
-    "phi-3-mini": ModelCard(
-        short_id="phi-3-mini",
-        model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
-        name="Phi 3 Mini 128k (4-bit)",
-        description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
-        tags=[],
-        metadata=ModelMetadata(
-            model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
-            pretty_name="Phi 3 Mini 128k (4-bit)",
-            storage_size=Memory.from_mb(2099),
-            n_layers=32,
-            hidden_size=8192,
-            supports_tensor=True,
-        ),
-    ),
     # qwen3
@@ -283,6 +323,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 0.6B (4-bit)",
             storage_size=Memory.from_mb(327),
             n_layers=28,
+            hidden_size=1024,
+            supports_tensor=False,
         ),
     ),
     "qwen3-0.6b-8bit": ModelCard(
@@ -296,6 +338,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 0.6B (8-bit)",
             storage_size=Memory.from_mb(666),
             n_layers=28,
+            hidden_size=1024,
+            supports_tensor=False,
         ),
     ),
     "qwen3-30b": ModelCard(
@@ -309,6 +353,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 30B A3B (4-bit)",
             storage_size=Memory.from_mb(16797),
             n_layers=48,
+            hidden_size=2048,
+            supports_tensor=True,
         ),
     ),
     "qwen3-30b-8bit": ModelCard(
@@ -322,6 +368,68 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 30B A3B (8-bit)",
             storage_size=Memory.from_mb(31738),
             n_layers=48,
+            hidden_size=2048,
+            supports_tensor=True,
         ),
     ),
+    "qwen3-80b-a3B-4bit": ModelCard(
+        short_id="qwen3-80b-a3B-4bit",
+        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
+        name="Qwen3 80B A3B (4-bit)",
+        description="""Qwen3 80B""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
+            pretty_name="Qwen3 80B A3B (4-bit)",
+            storage_size=Memory.from_mb(44800),
+            n_layers=48,
+            hidden_size=2048,
+            supports_tensor=True,
+        ),
+    ),
+    "qwen3-80b-a3B-8bit": ModelCard(
+        short_id="qwen3-80b-a3B-8bit",
+        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
+        name="Qwen3 80B A3B (8-bit)",
+        description="""Qwen3 80B""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
+            pretty_name="Qwen3 80B A3B (8-bit)",
+            storage_size=Memory.from_mb(84700),
+            n_layers=48,
+            hidden_size=2048,
+            supports_tensor=True,
+        ),
+    ),
+    "qwen3-80b-a3B-thinking-4bit": ModelCard(
+        short_id="qwen3-80b-a3B-thinking-4bit",
+        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
+        name="Qwen3 80B A3B Thinking (4-bit)",
+        description="""Qwen3 80B Reasoning model""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
+            pretty_name="Qwen3 80B A3B (4-bit)",
+            storage_size=Memory.from_mb(84700),
+            n_layers=48,
+            hidden_size=2048,
+            supports_tensor=True,
+        ),
+    ),
+    "qwen3-80b-a3B-thinking-8bit": ModelCard(
+        short_id="qwen3-80b-a3B-thinking-8bit",
+        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
+        name="Qwen3 80B A3B Thinking (8-bit)",
+        description="""Qwen3 80B Reasoning model""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
+            pretty_name="Qwen3 80B A3B (8-bit)",
+            storage_size=Memory.from_mb(84700),
+            n_layers=48,
+            hidden_size=2048,
+            supports_tensor=True,
+        ),
+    ),
     "qwen3-235b-a22b-4bit": ModelCard(
@@ -335,6 +443,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 235B A22B (4-bit)",
             storage_size=Memory.from_gb(132),
             n_layers=94,
+            hidden_size=4096,
+            supports_tensor=True,
         ),
     ),
     "qwen3-235b-a22b-8bit": ModelCard(
@@ -348,6 +458,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 235B A22B (8-bit)",
             storage_size=Memory.from_gb(250),
             n_layers=94,
+            hidden_size=4096,
+            supports_tensor=True,
         ),
     ),
     "qwen3-coder-480b-a35b-4bit": ModelCard(
@@ -361,6 +473,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 Coder 480B A35B (4-bit)",
             storage_size=Memory.from_gb(270),
             n_layers=62,
+            hidden_size=6144,
+            supports_tensor=True,
         ),
     ),
     "qwen3-coder-480b-a35b-8bit": ModelCard(
@@ -374,77 +488,84 @@ MODEL_CARDS: dict[str, ModelCard] = {
             pretty_name="Qwen3 Coder 480B A35B (8-bit)",
             storage_size=Memory.from_gb(540),
             n_layers=62,
+            hidden_size=6144,
+            supports_tensor=True,
         ),
     ),
-    # granite
-    "granite-3.3-2b": ModelCard(
-        short_id="granite-3.3-2b",
-        model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
-        name="Granite 3.3 2B (FP16)",
-        description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
+    # gpt-oss
+    "gpt-oss-120b-MXFP4-Q8": ModelCard(
+        short_id="gpt-oss-120b-MXFP4-Q8",
+        model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
+        name="GPT-OSS 120B (MXFP4-Q8, MLX)",
+        description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
         tags=[],
         metadata=ModelMetadata(
-            model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
-            pretty_name="Granite 3.3 2B (FP16)",
-            storage_size=Memory.from_mb(4951),
-            n_layers=40,
+            model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
+            pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
+            storage_size=Memory.from_kb(68_996_301),
+            n_layers=36,
+            hidden_size=2880,
+            supports_tensor=True,
         ),
     ),
-    # "granite-3.3-8b": ModelCard(
-    #     short_id="granite-3.3-8b",
-    #     model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
-    #     name="Granite 3.3 8B",
-    #     description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
+    "gpt-oss-20b-4bit": ModelCard(
+        short_id="gpt-oss-20b-4bit",
+        model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
+        name="GPT-OSS 20B (MXFP4-Q4, MLX)",
+        description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
+            pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
+            storage_size=Memory.from_kb(11_744_051),
+            n_layers=24,
+            hidden_size=2880,
+            supports_tensor=True,
+        ),
+    ),
+    # Needs to be quantized g32 or g16.
+    "glm-4.5-air-8bit": ModelCard(
+        short_id="glm-4.5-air-8bit",
+        model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
+        name="GLM 4.5 Air 8bit",
+        description="""GLM 4.5 Air 8bit""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
+            pretty_name="GLM 4.5 Air 8bit",
+            storage_size=Memory.from_gb(114),
+            n_layers=46,
+            hidden_size=4096,
+            supports_tensor=False,
+        ),
+    ),
+    "glm-4.5-air-bf16": ModelCard(
+        short_id="glm-4.5-air-bf16",
+        model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
+        name="GLM 4.5 Air bf16",
+        description="""GLM 4.5 Air bf16""",
+        tags=[],
+        metadata=ModelMetadata(
+            model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
+            pretty_name="GLM 4.5 Air bf16",
+            storage_size=Memory.from_gb(214),
+            n_layers=46,
+            hidden_size=4096,
+            supports_tensor=True,
+        ),
+    ),
+    # "devstral-2-123b-instruct-2512-8bit": ModelCard(
+    #     short_id="devstral-2-123b-instruct-2512-8bit",
+    #     model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
+    #     name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
+    #     description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
     #     tags=[],
     #     metadata=ModelMetadata(
-    #         model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
-    #         pretty_name="Granite 3.3 8B",
-    #         storage_size=Memory.from_kb(15958720),
-    #         n_layers=40,
-    #     ),
-    # ),
-    # smol-lm
-    # "smol-lm-135m": ModelCard(
-    #     short_id="smol-lm-135m",
-    #     model_id="mlx-community/SmolLM-135M-4bit",
-    #     name="Smol LM 135M",
-    #     description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
-    #     tags=[],
-    #     metadata=ModelMetadata(
-    #         model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
-    #         pretty_name="Smol LM 135M",
-    #         storage_size=Memory.from_kb(73940),
-    #         n_layers=30,
-    #     ),
-    # ),
-    # gpt-oss
-    # "gpt-oss-120b-MXFP4-Q8": ModelCard(
-    #     short_id="gpt-oss-120b-MXFP4-Q8",
-    #     model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
-    #     name="GPT-OSS 120B (MXFP4-Q8, MLX)",
-    #     description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
-    #     tags=[],
-    #     metadata=ModelMetadata(
-    #         model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
-    #         pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
-    #         storage_size=Memory.from_kb(68_996_301),
-    #         n_layers=36,
-    #         hidden_size=2880,
-    #         supports_tensor=True,
-    #     ),
-    # ),
-    # "gpt-oss-20b-4bit": ModelCard(
-    #     short_id="gpt-oss-20b-4bit",
-    #     model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
-    #     name="GPT-OSS 20B (MXFP4-Q4, MLX)",
-    #     description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
-    #     tags=[],
-    #     metadata=ModelMetadata(
-    #         model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
-    #         pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
-    #         storage_size=Memory.from_kb(11_744_051),
-    #         n_layers=24,
-    #         hidden_size=2880,
+    #         model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
+    #         pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
+    #         storage_size=Memory.from_kb(133_000_000),
+    #         n_layers=88,
+    #         hidden_size=12288,
     #         supports_tensor=True,
     #     ),
     # ),
@@ -6,6 +6,7 @@ from huggingface_hub import model_info
 from loguru import logger
 from pydantic import BaseModel, Field

+from exo.shared.models.model_cards import MODEL_CARDS
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
 from exo.worker.download.download_utils import (
@@ -25,6 +26,7 @@ class ConfigData(BaseModel):
     n_layers: Annotated[int, Field(ge=0)] | None = None  # Sometimes used
     num_decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Transformer models
     decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Some architectures
+    hidden_size: Annotated[int, Field(ge=0)] | None = None

     @property
     def layer_count(self) -> int:
@@ -106,10 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
     config_data = await get_config_data(model_id)
     num_layers = config_data.layer_count
     mem_size_bytes = await get_safetensors_size(model_id)
+    model_card = next(
+        (card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
+        None,
+    )

     return ModelMetadata(
         model_id=ModelId(model_id),
-        pretty_name=model_id,
+        pretty_name=model_card.name if model_card is not None else model_id,
         storage_size=mem_size_bytes,
         n_layers=num_layers,
+        hidden_size=config_data.hidden_size or 0,
+        # TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
+        supports_tensor=model_card.metadata.supports_tensor
+        if model_card is not None
+        else False,
     )
@@ -36,6 +36,8 @@ def get_pipeline_shard_metadata(
            pretty_name=str(model_id),
            storage_size=Memory.from_mb(100000),
            n_layers=32,
            hidden_size=1000,
            supports_tensor=True,
        ),
        device_rank=device_rank,
        world_size=world_size,
@@ -39,4 +39,7 @@ def test_apply_two_node_download_progress():
        NodeDownloadProgress(download_progress=event2), state
    )

    # TODO: This test is failing. We should support the following:
    # 1. Downloading multiple models concurrently on the same node (one per runner is fine).
    # 2. Downloading a model, it completes, then downloading a different model on the same node.
    assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection


def test_state_serialization_roundtrip() -> None:
@@ -11,12 +11,14 @@ def test_state_serialization_roundtrip() -> None:
    node_a = NodeId("node-a")
    node_b = NodeId("node-b")

    connection = SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
    connection = Connection(
        local_node_id=node_a,
        send_back_node_id=node_b,
        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
    )

    state = State()
    state.topology.add_connection(node_a, node_b, connection)
    state.topology.add_connection(connection)

    json_repr = state.model_dump_json()
    restored_state = State.model_validate_json(json_repr)
@@ -1,219 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable

import rustworkx as rx
from pydantic import BaseModel, ConfigDict

from exo.shared.types.common import NodeId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo


class TopologySnapshot(BaseModel):
    nodes: Sequence[NodeId]
    connections: Mapping[
        NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
    ]
    nodes: list[NodeInfo]
    connections: list[Connection]

    model_config = ConfigDict(frozen=True, extra="forbid")
    model_config = ConfigDict(frozen=True, extra="forbid", strict=True)


@dataclass
class Topology:
    # the _graph can be used as a int -> NodeId map.
    _graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
        init=False, default_factory=rx.PyDiGraph
    )
    _vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
    def __init__(self) -> None:
        self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
        self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
        self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
        self._edge_id_to_rx_id_map: dict[Connection, int] = dict()

    def to_snapshot(self) -> TopologySnapshot:
        return TopologySnapshot(
            nodes=list(self.list_nodes()), connections=self.map_connections()
            nodes=list(self.list_nodes()),
            connections=list(self.list_connections()),
        )

    @classmethod
    def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
        topology = cls()

        for node_id in snapshot.nodes:
        for node in snapshot.nodes:
            with contextlib.suppress(ValueError):
                topology.add_node(node_id)
                topology.add_node(node)

        for source in snapshot.connections:
            for sink in snapshot.connections[source]:
                for conn in snapshot.connections[source][sink]:
                    topology.add_connection(source, sink, conn)
        for connection in snapshot.connections:
            topology.add_connection(connection)

        return topology

    def add_node(self, node_id: NodeId) -> None:
        if node_id in self._vertex_indices:
    def add_node(self, node: NodeInfo) -> None:
        if node.node_id in self._node_id_to_rx_id_map:
            return
        rx_id = self._graph.add_node(node_id)
        self._vertex_indices[node_id] = rx_id
        rx_id = self._graph.add_node(node)
        self._node_id_to_rx_id_map[node.node_id] = rx_id
        self._rx_id_to_node_id_map[rx_id] = node.node_id

    def node_is_leaf(self, node_id: NodeId) -> bool:
        return (
            node_id in self._vertex_indices
            and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
            node_id in self._node_id_to_rx_id_map
            and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
        )

    def neighbours(self, node_id: NodeId) -> list[NodeId]:
        return [
            self._graph[rx_id]
            for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
            self._rx_id_to_node_id_map[rx_id]
            for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
        ]

    def out_edges(
        self, node_id: NodeId
    ) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
        if node_id not in self._vertex_indices:
    def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
        if node_id not in self._node_id_to_rx_id_map:
            return []
        return (
            (self._graph[nid], conn)
            for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
        )
        return [
            (self._rx_id_to_node_id_map[nid], conn)
            for _, nid, conn in self._graph.out_edges(
                self._node_id_to_rx_id_map[node_id]
            )
        ]

    def contains_node(self, node_id: NodeId) -> bool:
        return node_id in self._vertex_indices
        return node_id in self._node_id_to_rx_id_map

    def contains_connection(self, connection: Connection) -> bool:
        return connection in self._edge_id_to_rx_id_map

    def add_connection(
        self,
        source: NodeId,
        sink: NodeId,
        connection: SocketConnection | RDMAConnection,
        connection: Connection,
    ) -> None:
        if connection in self.get_all_connections_between(source, sink):
        if connection.local_node_id not in self._node_id_to_rx_id_map:
            self.add_node(NodeInfo(node_id=connection.local_node_id))
        if connection.send_back_node_id not in self._node_id_to_rx_id_map:
            self.add_node(NodeInfo(node_id=connection.send_back_node_id))

        if connection in self._edge_id_to_rx_id_map:
            return

        if source not in self._vertex_indices:
            self.add_node(source)
        if sink not in self._vertex_indices:
            self.add_node(sink)
        src_id = self._node_id_to_rx_id_map[connection.local_node_id]
        sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]

        src_id = self._vertex_indices[source]
        sink_id = self._vertex_indices[sink]
        rx_id = self._graph.add_edge(src_id, sink_id, connection)
        self._edge_id_to_rx_id_map[connection] = rx_id

        _ = self._graph.add_edge(src_id, sink_id, connection)
    def list_nodes(self) -> Iterable[NodeInfo]:
        return (self._graph[i] for i in self._graph.node_indices())

    def get_all_connections_between(
        self, source: NodeId, sink: NodeId
    ) -> Iterable[SocketConnection | RDMAConnection]:
        if source not in self._vertex_indices:
            return []
        if sink not in self._vertex_indices:
            return []
    def list_connections(self) -> Iterable[Connection]:
        return (connection for _, _, connection in self._graph.weighted_edge_list())

        src_id = self._vertex_indices[source]
        sink_id = self._vertex_indices[sink]
    def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
        try:
            return self._graph.get_all_edge_data(src_id, sink_id)
        except rx.NoEdgeBetweenNodes:
            return []
            rx_idx = self._node_id_to_rx_id_map[node_id]
            return self._graph.get_node_data(rx_idx).node_profile
        except KeyError:
            return None

    def list_nodes(self) -> Iterable[NodeId]:
        return self._graph.nodes()
    def update_node_profile(
        self, node_id: NodeId, node_profile: NodePerformanceProfile
    ) -> None:
        rx_idx = self._node_id_to_rx_id_map[node_id]
        self._graph[rx_idx].node_profile = node_profile

    def map_connections(
        self,
    ) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
        base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
        for src_id, sink_id, connection in self._graph.weighted_edge_list():
            source = self._graph[src_id]
            sink = self._graph[sink_id]
            if source not in base:
                base[source] = {}
            if sink not in base[source]:
                base[source][sink] = []
            base[source][sink].append(connection)
        return base
    def update_connection_profile(self, connection: Connection) -> None:
        rx_idx = self._edge_id_to_rx_id_map[connection]
        self._graph.update_edge_by_index(rx_idx, connection)

    def list_connections(
        self,
    ) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
        return (
            (
                self._graph[src_id],
                self._graph[sink_id],
                connection,
            )
            for src_id, sink_id, connection in self._graph.weighted_edge_list()
        )
    def get_connection_profile(
        self, connection: Connection
    ) -> ConnectionProfile | None:
        try:
            rx_idx = self._edge_id_to_rx_id_map[connection]
            return self._graph.get_edge_data_by_index(rx_idx).connection_profile
        except KeyError:
            return None

    def remove_node(self, node_id: NodeId) -> None:
        if node_id not in self._vertex_indices:
        if node_id not in self._node_id_to_rx_id_map:
            return

        rx_idx = self._vertex_indices[node_id]
        for connection in self.list_connections():
            if (
                connection.local_node_id == node_id
                or connection.send_back_node_id == node_id
            ):
                self.remove_connection(connection)

        rx_idx = self._node_id_to_rx_id_map[node_id]
        self._graph.remove_node(rx_idx)

        del self._vertex_indices[node_id]
        del self._node_id_to_rx_id_map[node_id]
        del self._rx_id_to_node_id_map[rx_idx]

    def replace_all_out_tb_connections(
        self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
    ) -> None:
        for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
            if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
                self._graph.remove_edge_from_index(conn_idx)
        for sink, conn in new_connections:
            self.add_connection(source, sink, conn)

    def remove_connection(
        self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
    ) -> None:
        if source not in self._vertex_indices or sink not in self._vertex_indices:
    def remove_connection(self, connection: Connection) -> None:
        if connection not in self._edge_id_to_rx_id_map:
            return
        for conn_idx in self._graph.edge_indices_from_endpoints(
            self._vertex_indices[source], self._vertex_indices[sink]
        ):
            if self._graph.get_edge_data_by_index(conn_idx) == edge:
                self._graph.remove_edge_from_index(conn_idx)
        rx_idx = self._edge_id_to_rx_id_map[connection]
        self._graph.remove_edge_from_index(rx_idx)
        del self._edge_id_to_rx_id_map[connection]

    def get_cycles(self) -> list[list[NodeId]]:
    def get_cycles(self) -> list[list[NodeInfo]]:
        cycle_idxs = rx.simple_cycles(self._graph)
        cycles: list[list[NodeId]] = []
        cycles: list[list[NodeInfo]] = []
        for cycle_idx in cycle_idxs:
            cycle = [self._graph[idx] for idx in cycle_idx]
            cycles.append(cycle)

        return cycles

    def get_cycles_tb(self) -> list[list[NodeId]]:
    def get_cycles_tb(self) -> list[list[NodeInfo]]:
        tb_edges = [
            (u, v, conn)
            for u, v, conn in self._graph.weighted_edge_list()
            if conn.is_thunderbolt()
        ]

        tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
        tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
        tb_graph.add_nodes_from(self._graph.nodes())

        for u, v, conn in tb_edges:
            if isinstance(conn, SocketConnection):
                tb_graph.add_edge(u, v, conn)
            tb_graph.add_edge(u, v, conn)

        cycle_idxs = rx.simple_cycles(tb_graph)
        cycles: list[list[NodeId]] = []
        cycles: list[list[NodeInfo]] = []
        for cycle_idx in cycle_idxs:
            cycle = [tb_graph[idx] for idx in cycle_idx]
            cycles.append(cycle)

        return cycles

    def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
        rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
    def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
        node_idxs = [node.node_id for node in nodes]
        rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
        topology = Topology()
        for rx_idx in rx_idxs:
            topology.add_node(self._graph[rx_idx])
        for source, sink, connection in self.list_connections():
            if source in node_ids and sink in node_ids:
                topology.add_connection(source, sink, connection)
        for connection in self.list_connections():
            if (
                connection.local_node_id in node_idxs
                and connection.send_back_node_id in node_idxs
            ):
                topology.add_connection(connection)
        return topology

    def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
        node_idxs = [node for node in cycle]
        rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
    def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
        node_idxs = [node.node_id for node in cycle]
        rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
        for rid in rx_idxs:
            for neighbor_rid in self._graph.neighbors(rid):
                if neighbor_rid not in rx_idxs:
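The hunk above replaces the old `(source, sink, SocketConnection | RDMAConnection)` bookkeeping with a single `Connection`-keyed graph. A minimal usage sketch, not part of the diff; the `NodeId` and `Multiaddr` values are invented for illustration:

```python
# Sketch only: exercises the reworked Topology API from the hunk above.
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import Connection

topology = Topology()

conn = Connection(
    local_node_id=NodeId("node-a"),
    send_back_node_id=NodeId("node-b"),
    send_back_multiaddr=Multiaddr(address="/ip4/192.168.1.7/tcp/10001"),
)

# add_connection now registers missing endpoints as bare NodeInfo entries,
# so no explicit add_node calls are required beforehand.
topology.add_connection(conn)
assert topology.contains_connection(conn)
assert topology.contains_node(NodeId("node-a"))

# Snapshots round-trip through flat node/connection lists.
restored = Topology.from_snapshot(topology.to_snapshot())
assert restored.contains_node(NodeId("node-b"))
```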
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding

@@ -174,6 +174,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    model_meta: ModelMetadata


class DeleteInstanceResponse(BaseModel):
@@ -2,14 +2,14 @@ from datetime import datetime

from pydantic import Field

from exo.shared.topology import SocketConnection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
    runner_id: RunnerId


# TODO
class NodeCreated(BaseEvent):
    node_id: NodeId


class NodeTimedOut(BaseEvent):
    node_id: NodeId


# TODO: bikeshed this naem
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
    node_id: NodeId
    when: str  # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
    info: GatheredInfo  # NB: this model is UNTAGGED!!! be warned for ser/de errors.
    node_profile: NodePerformanceProfile


class NodeMemoryMeasured(BaseEvent):
    node_id: NodeId
    when: str  # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
    memory: MemoryPerformanceProfile


class NodeDownloadProgress(BaseEvent):
@@ -97,15 +107,11 @@ class ChunkGenerated(BaseEvent):


class TopologyEdgeCreated(BaseEvent):
    source: NodeId
    sink: NodeId
    edge: SocketConnection
    edge: Connection


class TopologyEdgeDeleted(BaseEvent):
    source: NodeId
    sink: NodeId
    edge: SocketConnection
    edge: Connection


Event = (
@@ -119,8 +125,10 @@ Event = (
    | InstanceDeleted
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeCreated
    | NodeTimedOut
    | NodeGatheredInfo
    | NodePerformanceMeasured
    | NodeMemoryMeasured
    | NodeDownloadProgress
    | ChunkGenerated
    | TopologyEdgeCreated
@@ -14,3 +14,5 @@ class ModelMetadata(CamelCaseModel):
    pretty_name: str
    storage_size: Memory
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
@@ -1,11 +1,10 @@
import re
from typing import ClassVar

from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator


class Multiaddr(BaseModel):
    model_config = ConfigDict(frozen=True)
    address: str

    PATTERNS: ClassVar[list[str]] = [
@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self

import psutil

from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel


class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
    ram_total: Memory
    ram_available: Memory
    swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
    sys_power: float = 0.0
    pcpu_usage: float = 0.0
    ecpu_usage: float = 0.0
    ane_power: float = 0.0


class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,16 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):


class NodePerformanceProfile(CamelCaseModel):
    model_id: str = "Unknown"
    chip_id: str = "Unknown"
    friendly_name: str = "Unknown"
    memory: MemoryUsage = MemoryUsage.from_bytes(
        ram_total=0, ram_available=0, swap_total=0, swap_available=0
    )
    network_interfaces: Sequence[NetworkInterfaceInfo] = []
    tb_interfaces: Sequence[TBIdentifier] = []
    system: SystemPerformanceProfile = SystemPerformanceProfile()
    model_id: str
    chip_id: str
    friendly_name: str
    memory: MemoryPerformanceProfile
    network_interfaces: list[NetworkInterfaceInfo] = []
    system: SystemPerformanceProfile


class ConnectionProfile(CamelCaseModel):
    pass
    throughput: float
    latency: float
    jitter: float
@@ -40,6 +40,10 @@ class LoadModel(BaseTask):  # emitted by Worker
    pass


class ConnectToGroup(BaseTask):  # emitted by Worker
    pass


class StartWarmup(BaseTask):  # emitted by Worker
    pass

@@ -57,5 +61,11 @@ class Shutdown(BaseTask):  # emitted by Worker


Task = (
    CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
    CreateRunner
    | DownloadModel
    | ConnectToGroup
    | LoadModel
    | StartWarmup
    | ChatCompletion
    | Shutdown
)
@@ -1,64 +0,0 @@
import anyio
from pydantic import BaseModel, Field

from exo.utils.pydantic_ext import CamelCaseModel


class TBConnection(CamelCaseModel):
    source_uuid: str
    sink_uuid: str


class TBIdentifier(CamelCaseModel):
    rdma_interface: str
    domain_uuid: str


# Intentionally minimal, only collecting data we care about - there's a lot more


class TBReceptacleTag(BaseModel, extra="ignore"):
    receptacle_id_key: str


class TBConnectivityItem(BaseModel, extra="ignore"):
    domain_uuid_key: str | None


class TBConnectivityData(BaseModel, extra="ignore"):
    domain_uuid_key: str | None
    device_name_key: str
    items: list[TBConnectivityItem] | None = Field(None, alias="_items")
    receptacle_1_tag: TBReceptacleTag

    def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
        if self.domain_uuid_key is None:
            return
        tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
        iface = f"rdma_{ifaces[tag]}"
        return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)

    def conn(self) -> TBConnection | None:
        if self.domain_uuid_key is None or self.items is None:
            return

        sink_key = next(
            item.domain_uuid_key
            for item in self.items
            if item.domain_uuid_key is not None
        )
        return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)


class TBConnectivity(BaseModel):
    SPThunderboltDataType: list[TBConnectivityData]

    @classmethod
    async def gather(cls) -> list[TBConnectivityData] | None:
        proc = await anyio.run_process(
            ["system_profiler", "SPThunderboltDataType", "-json"], check=False
        )
        if proc.returncode != 0:
            return None
        # Saving you from PascalCase while avoiding too much pydantic
        return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType
@@ -1,32 +1,37 @@
from enum import Enum

from loguru import logger

from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel


class RDMAConnection(FrozenModel):
    source_rdma_iface: str
    sink_rdma_iface: str
class NodeInfo(CamelCaseModel):
    node_id: NodeId
    node_profile: NodePerformanceProfile | None = None


class Connection(CamelCaseModel):
    local_node_id: NodeId
    send_back_node_id: NodeId
    send_back_multiaddr: Multiaddr
    connection_profile: ConnectionProfile | None = None

    def __hash__(self) -> int:
        return hash(
            (
                self.local_node_id,
                self.send_back_node_id,
                self.send_back_multiaddr.address,
            )
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Connection):
            raise ValueError("Cannot compare Connection with non-Connection")
        return (
            self.local_node_id == other.local_node_id
            and self.send_back_node_id == other.send_back_node_id
            and self.send_back_multiaddr == other.send_back_multiaddr
        )

    def is_thunderbolt(self) -> bool:
        logger.warning("duh")
        return True


# TODO
class LinkType(str, Enum):
    Thunderbolt = "Thunderbolt"
    Ethernet = "Ethernet"
    WiFi = "WiFi"


class SocketConnection(FrozenModel):
    sink_multiaddr: Multiaddr

    def __hash__(self):
        return hash(self.sink_multiaddr.ip_address)

    def is_thunderbolt(self) -> bool:
        return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
        return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
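`Connection` deliberately hashes and compares only its identity fields, ignoring `connection_profile`, so a profiled and an unprofiled connection count as the same edge. A hedged sketch of the consequence, with invented values:

```python
# Sketch (assumed values): equality and hashing ignore connection_profile,
# which is what lets Topology key _edge_id_to_rx_id_map by Connection while
# profiles change over time.
a = Connection(
    local_node_id=NodeId("n1"),
    send_back_node_id=NodeId("n2"),
    send_back_multiaddr=Multiaddr(address="/ip4/169.254.0.2/tcp/9000"),
)
b = a.model_copy(
    update={
        "connection_profile": ConnectionProfile(
            throughput=1e9, latency=0.2, jitter=0.01
        )
    }
)
assert a == b and hash(a) == hash(b)  # same edge despite different profiles
```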
@@ -25,11 +25,12 @@ class BaseInstance(TaggedModel):


class MlxRingInstance(BaseInstance):
    hosts: list[Host]
    hosts_by_node: dict[NodeId, list[Host]]
    ephemeral_port: int


class MlxJacclInstance(BaseInstance):
    jaccl_devices: list[list[str | None]]
    ibv_devices: list[list[str | None]]
    jaccl_coordinators: dict[NodeId, str]
src/exo/shared/types/worker/resource_monitor.py (new file, +43)
@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable

from exo.shared.types.profiling import (
    MemoryPerformanceProfile,
    SystemPerformanceProfile,
)


class ResourceCollector(ABC):
    @abstractmethod
    async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...


class SystemResourceCollector(ResourceCollector):
    async def collect(self) -> SystemPerformanceProfile: ...


class MemoryResourceCollector(ResourceCollector):
    async def collect(self) -> MemoryPerformanceProfile: ...


class ResourceMonitor:
    data_collectors: list[ResourceCollector]
    effect_handlers: set[
        Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
    ]

    async def _collect(
        self,
    ) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
        tasks: list[
            Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
        ] = [collector.collect() for collector in self.data_collectors]
        return await asyncio.gather(*tasks)

    async def collect(self) -> None:
        profiles = await self._collect()
        for profile in profiles:
            for effect_handler in self.effect_handlers:
                effect_handler(profile)
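The new `ResourceMonitor` fans collected profiles out to registered handlers. A wiring sketch, not part of the diff; the concrete `collect` bodies above are stubs (`...`), so against this file the handler would only receive `None` — with real collectors it receives profiles:

```python
# Sketch only: intended call pattern for ResourceMonitor.
import asyncio

def log_profile(profile: SystemPerformanceProfile | MemoryPerformanceProfile) -> None:
    print(type(profile).__name__, profile)

monitor = ResourceMonitor()
monitor.data_collectors = [SystemResourceCollector(), MemoryResourceCollector()]
monitor.effect_handlers = {log_profile}

# collect() awaits every collector concurrently, then fans each profile
# out to every registered effect handler.
asyncio.run(monitor.collect())
```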
@@ -21,7 +21,15 @@ class BaseRunnerStatus(TaggedModel):
        return isinstance(self, RunnerRunning)


class RunnerWaitingForModel(BaseRunnerStatus):
class RunnerIdle(BaseRunnerStatus):
    pass


class RunnerConnecting(BaseRunnerStatus):
    pass


class RunnerConnected(BaseRunnerStatus):
    pass


@@ -54,7 +62,9 @@ class RunnerFailed(BaseRunnerStatus):


RunnerStatus = (
    RunnerWaitingForModel
    RunnerIdle
    | RunnerConnecting
    | RunnerConnected
    | RunnerLoading
    | RunnerLoaded
    | RunnerWarmingUp
@@ -1,231 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast

import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger

from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
    MemoryUsage,
    NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel

from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces

IS_DARWIN = sys.platform == "darwin"


class StaticNodeInformation(TaggedModel):
    """Node information that should NEVER change, to be gathered once at startup"""

    model: str
    chip: str

    @classmethod
    async def gather(cls) -> Self:
        model, chip = await get_model_and_chip()
        return cls(model=model, chip=chip)


class NodeNetworkInterfaces(TaggedModel):
    ifaces: Sequence[NetworkInterfaceInfo]


class MacTBIdentifiers(TaggedModel):
    idents: Sequence[TBIdentifier]


class MacTBConnections(TaggedModel):
    conns: Sequence[TBConnection]


class NodeConfig(TaggedModel):
    """Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""

    # TODO
    @classmethod
    async def gather(cls) -> Self | None:
        cfg_file = anyio.Path(EXO_CONFIG_FILE)
        await cfg_file.touch(exist_ok=True)
        async with await cfg_file.open("rb") as f:
            try:
                contents = (await f.read()).decode("utf-8")
                data = tomllib.loads(contents)
                return cls.model_validate(data)
            except (tomllib.TOMLDecodeError, UnicodeDecodeError):
                logger.warning("Invalid config file, skipping...")
                return None


class MiscData(TaggedModel):
    """Node information that may slowly change that doesn't fall into the other categories"""

    friendly_name: str

    @classmethod
    async def gather(cls) -> Self:
        return cls(friendly_name=await get_friendly_name())


async def _gather_iface_map() -> dict[str, str] | None:
    proc = await anyio.run_process(
        ["networksetup", "-listallhardwareports"], check=False
    )
    if proc.returncode != 0:
        return None

    ports: dict[str, str] = {}
    port = ""
    for line in proc.stdout.decode("utf-8").split("\n"):
        if line.startswith("Hardware Port:"):
            port = line.split(": ")[1]
        elif line.startswith("Device:"):
            ports[port] = line.split(": ")[1]
            port = ""
    if "" in ports:
        del ports[""]
    return ports


GatheredInfo = (
    MacmonMetrics
    | MemoryUsage
    | NodeNetworkInterfaces
    | MacTBIdentifiers
    | MacTBConnections
    | NodeConfig
    | MiscData
    | StaticNodeInformation
)


@dataclass
class InfoGatherer:
    info_sender: Sender[GatheredInfo]
    interface_watcher_interval: float | None = 10
    misc_poll_interval: float | None = 60
    system_profiler_interval: float | None = 5 if IS_DARWIN else None
    memory_poll_rate: float | None = None if IS_DARWIN else 1
    macmon_interval: float | None = 1 if IS_DARWIN else None
    _tg: TaskGroup = field(init=False, default_factory=create_task_group)

    async def run(self):
        async with self._tg as tg:
            if (macmon_path := shutil.which("macmon")) is not None:
                tg.start_soon(self._monitor_macmon, macmon_path)
            if IS_DARWIN:
                tg.start_soon(self._monitor_system_profiler)
            tg.start_soon(self._watch_system_info)
            tg.start_soon(self._monitor_memory_usage)
            tg.start_soon(self._monitor_misc)

            nc = await NodeConfig.gather()
            if nc is not None:
                await self.info_sender.send(nc)
            sni = await StaticNodeInformation.gather()
            await self.info_sender.send(sni)

    def shutdown(self):
        self._tg.cancel_scope.cancel()

    async def _monitor_misc(self):
        if self.misc_poll_interval is None:
            return
        prev = await MiscData.gather()
        while True:
            curr = await MiscData.gather()
            if prev != curr:
                prev = curr
                await self.info_sender.send(curr)
            await anyio.sleep(self.misc_poll_interval)

    async def _monitor_system_profiler(self):
        if self.system_profiler_interval is None:
            return
        iface_map = await _gather_iface_map()
        if iface_map is None:
            return

        old_idents = []
        while True:
            data = await TBConnectivity.gather()
            assert data is not None

            idents = [it for i in data if (it := i.ident(iface_map)) is not None]
            if idents != old_idents:
                await self.info_sender.send(MacTBIdentifiers(idents=idents))
                old_idents = idents

            conns = [it for i in data if (it := i.conn()) is not None]
            await self.info_sender.send(MacTBConnections(conns=conns))

            await anyio.sleep(self.system_profiler_interval)

    async def _monitor_memory_usage(self):
        override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
        override_memory: int | None = (
            Memory.from_mb(int(override_memory_env)).in_bytes
            if override_memory_env
            else None
        )
        if self.memory_poll_rate is None:
            return
        while True:
            await self.info_sender.send(
                MemoryUsage.from_psutil(override_memory=override_memory)
            )
            await anyio.sleep(self.memory_poll_rate)

    async def _watch_system_info(self):
        if self.interface_watcher_interval is None:
            return
        old_nics = []
        while True:
            nics = get_network_interfaces()
            if nics != old_nics:
                old_nics = nics
                await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
            await anyio.sleep(self.interface_watcher_interval)

    async def _monitor_macmon(self, macmon_path: str):
        if self.macmon_interval is None:
            return
        # macmon pipe --interval [interval in ms]
        try:
            async with await open_process(
                [macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
            ) as p:
                if not p.stdout:
                    logger.critical("MacMon closed stdout")
                    return
                async for text in TextReceiveStream(
                    BufferedByteReceiveStream(p.stdout)
                ):
                    await self.info_sender.send(MacmonMetrics.from_raw_json(text))
        except CalledProcessError as e:
            stderr_msg = "no stderr"
            stderr_output = cast(bytes | str | None, e.stderr)
            if stderr_output is not None:
                stderr_msg = (
                    stderr_output.decode()
                    if isinstance(stderr_output, bytes)
                    else str(stderr_output)
                )
            logger.warning(
                f"MacMon failed with return code {e.returncode}: {stderr_msg}"
            )
@@ -1,70 +0,0 @@
from typing import Self

from pydantic import BaseModel

from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel


class _TempMetrics(BaseModel, extra="ignore"):
    """Temperature-related metrics returned by macmon."""

    cpu_temp_avg: float
    gpu_temp_avg: float


class _MemoryMetrics(BaseModel, extra="ignore"):
    """Memory-related metrics returned by macmon."""

    ram_total: int
    ram_usage: int
    swap_total: int
    swap_usage: int


class RawMacmonMetrics(BaseModel, extra="ignore"):
    """Complete set of metrics returned by macmon.

    Unknown fields are ignored for forward-compatibility.
    """

    timestamp: str  # ignored
    temp: _TempMetrics
    memory: _MemoryMetrics
    ecpu_usage: tuple[int, float]  # freq mhz, usage %
    pcpu_usage: tuple[int, float]  # freq mhz, usage %
    gpu_usage: tuple[int, float]  # freq mhz, usage %
    all_power: float
    ane_power: float
    cpu_power: float
    gpu_power: float
    gpu_ram_power: float
    ram_power: float
    sys_power: float


class MacmonMetrics(TaggedModel):
    system_profile: SystemPerformanceProfile
    memory: MemoryUsage

    @classmethod
    def from_raw(cls, raw: RawMacmonMetrics) -> Self:
        return cls(
            system_profile=SystemPerformanceProfile(
                gpu_usage=raw.gpu_usage[1],
                temp=raw.temp.gpu_temp_avg,
                sys_power=raw.sys_power,
                pcpu_usage=raw.pcpu_usage[1],
                ecpu_usage=raw.ecpu_usage[1],
            ),
            memory=MemoryUsage.from_bytes(
                ram_total=raw.memory.ram_total,
                ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
                swap_total=raw.memory.swap_total,
                swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
            ),
        )

    @classmethod
    def from_raw_json(cls, json: str) -> Self:
        return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
@@ -1,56 +0,0 @@
import socket
from collections.abc import Mapping
from ipaddress import ip_address

from anyio import create_task_group, to_thread

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile


# TODO: ref. api port
async def check_reachability(
    target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
) -> None:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(1)  # 1 second timeout
    try:
        result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
    except socket.gaierror:
        # seems to throw on ipv6 loopback. oh well
        # logger.warning(f"invalid {target_ip=}")
        return
    finally:
        sock.close()

    if result == 0:
        if target_node_id not in out:
            out[target_node_id] = set()
        out[target_node_id].add(target_ip)


async def check_reachable(
    our_node_id: NodeId,
    topology: Topology,
    profiles: Mapping[NodeId, NodePerformanceProfile],
) -> Mapping[NodeId, set[str]]:
    reachable: dict[NodeId, set[str]] = {}
    our_profile = profiles.get(our_node_id, None)
    if our_profile is None:
        return {}
    our_interfaces = our_profile.network_interfaces
    async with create_task_group() as tg:
        for node_id in topology.list_nodes():
            if node_id not in profiles or node_id == our_node_id:
                continue
            for iface in profiles[node_id].network_interfaces:
                if ip_address(iface.ip_address).is_loopback:
                    # Definitely a loopback address
                    continue
                if iface in our_interfaces:
                    # Skip duplicates with our own interfaces
                    continue
                tg.start_soon(check_reachability, iface.ip_address, node_id, reachable)

    return reachable
@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
        alias_generator=to_camel,
        validate_by_name=True,
        extra="forbid",
        # I want to reenable this ASAP, but it's causing an issue with TaskStatus
        strict=True,
    )


class FrozenModel(BaseModel):
    model_config = ConfigDict(
        alias_generator=to_camel,
        validate_by_name=True,
        extra="forbid",
        strict=True,
        frozen=True,
    )


class TaggedModel(CamelCaseModel):
    @model_serializer(mode="wrap")
    def _serialize(self, handler: SerializerFunctionWrapHandler):
@@ -28,8 +28,9 @@ def bar(send: MpSender[str]):
    send.close()


# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_ipc():
async def test_channel_setup():
    with fail_after(0.5):
        s, r = mp_channel[str]()
        p1 = mp.Process(target=foo, args=(r,))
@@ -95,7 +95,15 @@ def extract_layer_num(tensor_name: str) -> int | None:

def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list[str]:
    default_patterns = set(
        ["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt", "*.jinja"]
        [
            "*.json",
            "*.py",
            "tokenizer.model",
            "tiktoken.model",
            "*.tiktoken",
            "*.txt",
            "*.jinja",
        ]
    )
    shard_specific_patterns: set[str] = set()
    if weight_map:
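The defaults above are shell-style globs that every shard downloads, with weight files added per shard. A hedged illustration of how such an allow-list filters a repo file listing (file names invented; exo's actual matching code may differ):

```python
# Sketch only: glob-based filtering of a repo listing.
from fnmatch import fnmatch

repo_files = [
    "config.json",
    "tokenizer.model",
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]
allow = ["*.json", "tokenizer.model", "model-00001-of-00002.safetensors"]

selected = [f for f in repo_files if any(fnmatch(f, p) for p in allow)]
print(selected)  # ['config.json', 'tokenizer.model', 'model-00001-of-00002.safetensors']
```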
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -12,7 +13,7 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import RepoDownloadProgress


# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Shoudl this be a classmethod?
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
    @abstractmethod
    async def ensure_shard(
@@ -43,34 +44,7 @@ class ShardDownloader(ABC):
        Yields:
            tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.
        """
        yield (
            Path("/tmp/noop_shard"),
            RepoDownloadProgress(
                repo_id="noop",
                repo_revision="noop",
                shard=PipelineShardMetadata(
                    model_meta=ModelMetadata(
                        model_id=ModelId("noop"),
                        pretty_name="noope",
                        storage_size=Memory.from_bytes(0),
                        n_layers=1,
                    ),
                    device_rank=0,
                    world_size=1,
                    start_layer=0,
                    end_layer=1,
                    n_layers=1,
                ),
                completed_files=0,
                total_files=0,
                downloaded_bytes=Memory.from_bytes(0),
                downloaded_bytes_this_session=Memory.from_bytes(0),
                total_bytes=Memory.from_bytes(0),
                overall_speed=0,
                overall_eta=timedelta(seconds=0),
                status="complete",
            ),
        )
        yield (Path("/tmp/noop_shard"), NOOP_DOWNLOAD_PROGRESS)

    @abstractmethod
    async def get_shard_download_status_for_shard(
@@ -94,46 +68,41 @@ class NoopShardDownloader(ShardDownloader):
    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
        yield (
            Path("/tmp/noop_shard"),
            RepoDownloadProgress(
                repo_id="noop",
                repo_revision="noop",
                shard=PipelineShardMetadata(
                    model_meta=ModelMetadata(
                        model_id=ModelId("noop"),
                        pretty_name="noope",
                        storage_size=Memory.from_bytes(0),
                        n_layers=1,
                    ),
                    device_rank=0,
                    world_size=1,
                    start_layer=0,
                    end_layer=1,
                    n_layers=1,
                ),
                completed_files=0,
                total_files=0,
                downloaded_bytes=Memory.from_bytes(0),
                downloaded_bytes_this_session=Memory.from_bytes(0),
                total_bytes=Memory.from_bytes(0),
                overall_speed=0,
                overall_eta=timedelta(seconds=0),
                status="complete",
            ),
            NOOP_DOWNLOAD_PROGRESS,
        )

    async def get_shard_download_status_for_shard(
        self, shard: ShardMetadata
    ) -> RepoDownloadProgress:
        return RepoDownloadProgress(
            repo_id="noop",
            repo_revision="noop",
            shard=shard,
            completed_files=0,
            total_files=0,
            downloaded_bytes=Memory.from_bytes(0),
            downloaded_bytes_this_session=Memory.from_bytes(0),
            total_bytes=Memory.from_bytes(0),
            overall_speed=0,
            overall_eta=timedelta(seconds=0),
            status="complete",
        )
        dp = copy(NOOP_DOWNLOAD_PROGRESS)
        dp.shard = shard
        return dp


NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
    repo_id="noop",
    repo_revision="noop",
    shard=PipelineShardMetadata(
        model_meta=ModelMetadata(
            model_id=ModelId("noop"),
            pretty_name="noope",
            storage_size=Memory.from_bytes(0),
            n_layers=1,
            hidden_size=1,
            supports_tensor=False,
        ),
        device_rank=0,
        world_size=1,
        start_layer=0,
        end_layer=1,
        n_layers=1,
    ),
    completed_files=0,
    total_files=0,
    downloaded_bytes=Memory.from_bytes(0),
    downloaded_bytes_this_session=Memory.from_bytes(0),
    total_bytes=Memory.from_bytes(0),
    overall_speed=0,
    overall_eta=timedelta(seconds=0),
    status="complete",
)
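A note on the `copy()` refactor above: `copy.copy` is shallow, so reassigning `dp.shard` is safe, but mutating nested objects on the copy would write through to the shared `NOOP_DOWNLOAD_PROGRESS` constant. A hedged pydantic-v2 alternative that makes the same intent explicit:

```python
# Sketch only: equivalent to copy(NOOP_DOWNLOAD_PROGRESS) + reassignment,
# using pydantic's model_copy so the update is explicit and the module-level
# constant can never be mutated by accident.
def noop_status_for(shard: ShardMetadata) -> RepoDownloadProgress:
    return NOOP_DOWNLOAD_PROGRESS.model_copy(update={"shard": shard})
```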
@@ -10,7 +10,6 @@ KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
TEMPERATURE: float = 1.0

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True
@@ -13,7 +13,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
    CACHE_GROUP_SIZE,
    KV_CACHE_BITS,
    TEMPERATURE,
    TRUST_REMOTE_CODE,
)

@@ -21,6 +20,8 @@ try:
    from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
    from mlx_lm.tokenizer_utils import load as load_tokenizer  # type: ignore
import contextlib

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load_model
@@ -48,6 +49,7 @@ from exo.worker.engines.mlx.auto_parallel import (
)
from exo.worker.runner.bootstrap import logger

Group = mx.distributed.Group
# Needed for 8 bit model
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))

@@ -67,7 +69,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
    )


def mx_barrier(group: mx.distributed.Group | None = None):
def mx_barrier(group: Group | None = None):
    mx.eval(
        mx.distributed.all_sum(
            mx.array(1.0),
@@ -77,7 +79,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
    )


def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
def broadcast_from_zero(value: int, group: Group | None = None):
    if group is None:
        return value

@@ -99,85 +101,97 @@ class HostList(RootModel[list[str]]):

def mlx_distributed_init(
    bound_instance: BoundInstance,
) -> mx.distributed.Group:
) -> Group:
    """
    Initialize the MLX distributed
    Initialize MLX distributed.
    """
    rank = bound_instance.bound_shard.device_rank
    logger.info(f"Starting initialization for rank {rank}")

    # TODO: singleton instances
    match bound_instance.instance:
        case MlxRingInstance(hosts=hosts):
            hostfile = f"./hosts_{rank}.json"
            hosts_json = HostList.from_hosts(hosts).model_dump_json()
    coordination_file = None
    try:
        # TODO: singleton instances
        match bound_instance.instance:
            case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
                coordination_file = (
                    f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
                )
                hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
                hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()

            with open(hostfile, "w") as f:
                _ = f.write(hosts_json)
                with open(coordination_file, "w") as f:
                    _ = f.write(hosts_json)

            logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
                logger.info(
                    f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
                )

            os.environ["MLX_HOSTFILE"] = hostfile
            os.environ["MLX_RANK"] = str(rank)
            os.environ["MLX_RING_VERBOSE"] = "1"
            group = mx.distributed.init(backend="ring", strict=True)
                os.environ["MLX_HOSTFILE"] = coordination_file
                os.environ["MLX_RANK"] = str(rank)
                os.environ["MLX_RING_VERBOSE"] = "1"
                group = mx.distributed.init(backend="ring", strict=True)

        case MlxJacclInstance(
            jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
        ):
            # Use RDMA connectivity matrix
            devices_file = f"./hosts_{rank}.json"
            jaccl_devices_json = json.dumps(jaccl_devices)
            case MlxJacclInstance(
                ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
            ):
                # Use RDMA connectivity matrix
                coordination_file = (
                    f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
                )
                ibv_devices_json = json.dumps(ibv_devices)

            with open(devices_file, "w") as f:
                _ = f.write(jaccl_devices_json)
                with open(coordination_file, "w") as f:
                    _ = f.write(ibv_devices_json)

            jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
                jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]

            logger.info(f"rank {rank} MLX_JACCL_DEVICES: {jaccl_devices_json}")
            logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
            os.environ["MLX_JACCL_DEVICES"] = devices_file
            os.environ["MLX_RANK"] = str(rank)
            os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
            group = mx.distributed.init(backend="jaccl", strict=True)
                logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
                logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
                os.environ["MLX_IBV_DEVICES"] = coordination_file
                os.environ["MLX_RANK"] = str(rank)
                os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
                group = mx.distributed.init(backend="jaccl", strict=True)

    logger.info(f"Rank {rank} mlx distributed initialization complete")
        logger.info(f"Rank {rank} mlx distributed initialization complete")

    return group
        return group
    finally:
        with contextlib.suppress(FileNotFoundError):
            if coordination_file:
                os.remove(coordination_file)


def initialize_mlx(
    bound_instance: BoundInstance,
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
    """
    Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
    """
) -> Group:
    # should we unseed it?
    # TODO: pass in seed from params
    mx.random.seed(42)

    set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
    assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
        "Tried to initialize mlx for a single node instance"
    )
    return mlx_distributed_init(bound_instance)

    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)

def load_mlx_items(
    bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
    # TODO: pass temperature
    sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
    logger.info("Created a sampler")

    if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
    if group is None:
        logger.info(f"Single device used for {bound_instance.instance}")
        model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
        start_time = time.perf_counter()
        model, _ = load_model(model_path, strict=True)
        end_time = time.perf_counter()
        logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
        if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model):  # type: ignore
            pass
            # model, config = quantize_model(
            #     model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
            # )

        tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)

    else:
        logger.info("Starting distributed init")
        group = mlx_distributed_init(bound_instance)

        start_time = time.perf_counter()
        model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
        end_time = time.perf_counter()
@@ -187,14 +201,12 @@ def initialize_mlx(

    set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))

    logger.debug(model)

    return cast(Model, model), tokenizer, sampler


def shard_and_load(
    shard_metadata: ShardMetadata,
    group: mx.distributed.Group,
    group: Group,
) -> tuple[nn.Module, TokenizerWrapper]:
    model_path = build_model_path(shard_metadata.model_meta.model_id)
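For orientation, the ring branch above boils down to: write the host list to a per-instance JSON file, point MLX at it through environment variables, initialise, and clean the file up afterwards. A reduced sketch (host strings and file name invented; the hostfile layout is assumed to match the `HostList` JSON shown in the diff):

```python
# Sketch only: the coordination-file handoff used by the ring branch above.
import json
import os

import mlx.core as mx

hosts = ["192.168.1.10:5000", "192.168.1.11:5000"]  # assumed format
rank = 0
coordination_file = f"./hosts_example_{rank}.json"
with open(coordination_file, "w") as f:
    json.dump(hosts, f)

os.environ["MLX_HOSTFILE"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
try:
    group = mx.distributed.init(backend="ring", strict=True)
finally:
    # mirror the diff's try/finally: remove the file even if init fails
    os.remove(coordination_file)
```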
@@ -16,13 +16,15 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
@@ -31,7 +33,7 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
@@ -42,14 +44,14 @@ from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.worker.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
|
||||
from exo.worker.utils.net_profile import check_reachable
|
||||
|
||||
|
||||
class Worker:
|
||||
@@ -83,7 +85,7 @@ class Worker:
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
self._nack_cancel_scope: CancelScope | None = None
|
||||
self._nack_attempts: int = 0
|
||||
@@ -95,13 +97,37 @@ class Worker:
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
info_send, info_recv = channel[GatheredInfo]()
|
||||
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||
# TODO: CLEANUP HEADER
|
||||
async def resource_monitor_callback(
|
||||
node_performance_profile: NodePerformanceProfile,
|
||||
) -> None:
|
||||
await self.event_sender.send(
|
||||
NodePerformanceMeasured(
|
||||
node_id=self.node_id,
|
||||
node_profile=node_performance_profile,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
),
|
||||
)
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
async def memory_monitor_callback(
|
||||
memory_profile: MemoryPerformanceProfile,
|
||||
) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeMemoryMeasured(
|
||||
node_id=self.node_id,
|
||||
memory=memory_profile,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
)
|
||||
)
|
||||
|
||||
# END CLEANUP
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
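The hunk above swaps the eagerly constructed `TaskGroup` for a `_tg: TaskGroup | None` that is only assigned once `run()` has entered `create_task_group()`. A minimal sketch of that lifetime pattern, assuming the usual anyio semantics (the class and method names here are illustrative, not the real `Worker` API):

```python
import anyio
from anyio.abc import TaskGroup


class Sketch:
    def __init__(self) -> None:
        # A TaskGroup is only valid inside its `async with`, so it starts unset.
        self._tg: TaskGroup | None = None

    async def run(self) -> None:
        async with anyio.create_task_group() as tg:
            self._tg = tg  # publish so other methods can start_soon/cancel
            tg.start_soon(anyio.sleep, 3600)  # stand-in for the worker loops

    def shutdown(self) -> None:
        # Guard against shutdown() racing ahead of (or after) run().
        if self._tg:
            self._tg.cancel_scope.cancel()
```

This is also why the later hunks add `assert self._tg` before each `start_soon` call: the attribute is now optional until `run()` begins.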
@@ -114,17 +140,6 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()

async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)

async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -144,6 +159,7 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -212,7 +228,7 @@ class Worker:
)
)
else:
self.event_sender.send_nowait(
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
@@ -232,7 +248,8 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)

def shutdown(self):
self._tg.cancel_scope.cancel()
if self._tg:
self._tg.cancel_scope.cancel()

def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -249,24 +266,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)

case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)

async def _nack_request(self, since_idx: int) -> None:
@@ -315,6 +332,7 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner

@@ -373,6 +391,7 @@ class Worker:
last_progress_time = current_time()

self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)

async def _forward_events(self) -> None:
@@ -395,35 +414,33 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(
self.node_id, self.state.topology, self.state.node_profiles
)
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
edge = SocketConnection(
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)

edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
source=self.node_id, sink=nid, edge=edge
)
)
await self.event_sender.send(TopologyEdgeCreated(edge=edge))

for nid, conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn, SocketConnection):
continue
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
nid, set()
if (
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
)
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))

await anyio.sleep(10)

@@ -5,6 +5,7 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
LoadModel,
@@ -14,17 +15,23 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerId,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
@@ -48,6 +55,7 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
@@ -106,9 +114,11 @@ def _model_needs_download(
download_status: Mapping[ShardMetadata, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
if (
isinstance(runner.status, RunnerWaitingForModel)
and runner.bound_instance.bound_shard not in download_status
if isinstance(runner.status, RunnerIdle) and (
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
@@ -117,14 +127,54 @@
)


""" --- TODO!
def _init_backend(
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
) -> LoadModel | None:
for runner in runner.values()
pass
"""
):
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments

is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
if is_single_node_instance:
continue

runner_is_idle = isinstance(runner.status, RunnerIdle)
all_runners_connecting = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerConnecting, RunnerIdle),
)
for global_runner_id in shard_assignments.runner_to_shard
)

if not (runner_is_idle and all_runners_connecting):
continue

runner_id = runner.bound_instance.bound_runner_id

shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
world_size = shard.world_size

assert device_rank < world_size
assert device_rank >= 0

accepting_ranks = device_rank < world_size - 1

# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
)

if not (accepting_ranks or connecting_rank_ready):
continue

return ConnectToGroup(instance_id=instance.instance_id)

return None
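`_init_distributed_backend` staggers group formation: every rank below `world_size - 1` may emit `ConnectToGroup` immediately, while the final rank waits until all of its peers already report `RunnerConnecting`. A standalone sketch of that gate, with statuses reduced to plain strings purely for illustration (the real code checks the `Runner*` status classes):

```python
def should_connect(
    device_rank: int, world_size: int, peer_statuses: list[str]
) -> bool:
    # Accepting ranks (everything but the last) connect right away.
    accepting = device_rank < world_size - 1
    # The last rank only connects once every peer is already connecting.
    last_rank_ready = device_rank == world_size - 1 and all(
        status == "connecting" for status in peer_statuses
    )
    return accepting or last_rank_ready


assert should_connect(0, 2, peer_statuses=["idle"])        # rank 0 leads
assert not should_connect(1, 2, peer_statuses=["idle"])    # rank 1 waits
assert should_connect(1, 2, peer_statuses=["connecting"])  # peers ready
```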


def _load_model(
@@ -136,31 +186,33 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments

all_downloads_complete_local = all(
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid, rid in shard_assignments.node_to_runner.items()
for nid in shard_assignments.node_to_runner
)
if not all_local_downloads_complete:
continue

runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)

all_runners_expecting_model = all(
is_runner_waiting = isinstance(runner.status, RunnerConnected)

all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)

if (
all_downloads_complete_local
and runner_is_waiting
and all_runners_expecting_model
):
if is_runner_waiting and all_ready_for_model:
return LoadModel(instance_id=instance.instance_id)

return None
@@ -183,8 +235,9 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0

# Rank != n-1
accepting_ranks_ready = device_rank != world_size - 1 and all(
# TODO: Ensure these align with MLX distributed's expectations.
# Rank < n-1
accepting_ranks_ready = device_rank < world_size - 1 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -221,6 +274,8 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue

# TODO: Check ordering aligns with MLX distributed's expectations.

if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard

@@ -2,16 +2,13 @@ import os

import loguru

from exo.shared.types.events import Event
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import MpReceiver, MpSender

logger: "loguru.Logger"


if os.getenv("EXO_TESTS") == "1":
logger = loguru.logger
logger: "loguru.Logger" = loguru.logger


def entrypoint(
@@ -22,7 +19,7 @@ def entrypoint(
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

@@ -30,6 +27,23 @@
logger = _logger

# Import main after setting global logger - this lets us just import logger from this module
from exo.worker.runner.runner import main
try:
from exo.worker.runner.runner import main

main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
)
event_sender.send(
RunnerStatusUpdated(
runner_id=bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=str(e)),
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")

@@ -11,6 +11,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
@@ -22,20 +23,23 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@@ -63,9 +67,10 @@ def main(
model = None
tokenizer = None
sampler = None
group = None

current_status: RunnerStatus = RunnerWaitingForModel()
logger.info("runner waiting for model")
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
@@ -78,9 +83,26 @@
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(
current_status, (RunnerWaitingForModel, RunnerFailed)
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)

logger.info("runner connected")
current_status = RunnerConnected()

# We load the model if the runner is connected with a group, or idle without a group; we should never tell a runner to connect if it doesn't need to.
case LoadModel() if (
isinstance(current_status, RunnerConnected)
and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
@@ -89,15 +111,12 @@
)
)

model, tokenizer, sampler = initialize_mlx(bound_instance)
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)

current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
@@ -123,11 +142,6 @@
)
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
@@ -172,11 +186,6 @@

current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case Shutdown():
logger.info("runner shutting down")
event_sender.send(
@@ -186,12 +195,19 @@
)
break
case _:
raise ValueError("Received task outside of state machine")
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
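Read end to end, the match arms above implement a small state machine. A condensed sketch of the happy-path transitions, with statuses and task names reduced to strings purely for illustration (the real code dispatches on the `Runner*` status classes and task types):

```python
# Hedged summary of the dispatch above, not executable against the real types.
HAPPY_PATH = {
    ("idle", "ConnectToGroup"): "connecting",  # completes as "connected"
    ("connected", "LoadModel"): "loading",     # multi-node: a group exists
    ("idle", "LoadModel"): "loading",          # single-node: no group needed
    ("loaded", "StartWarmup"): "warming_up",   # completes as "ready"
    ("ready", "ChatCompletion"): "running",    # returns to "ready" when done
    ("ready", "Shutdown"): "shutdown",
}
```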

@@ -19,8 +19,8 @@ from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerStatus,
RunnerWaitingForModel,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -41,7 +41,7 @@ class RunnerSupervisor:
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerWaitingForModel, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)

@classmethod

@@ -24,3 +24,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")

COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")

SHUTDOWN_TASK_ID = TaskId("shutdown")
CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
INITIALIZATION_TASK_ID = TaskId("initialisation")
LOAD_TASK_ID = TaskId("load")
WARMUP_TASK_ID = TaskId("warmup")

@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass

from exo.shared.types.common import NodeId
@@ -14,6 +16,7 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata


# Runner supervisor without multiprocessing logic.
@dataclass(frozen=True)
class FakeRunnerSupervisor:
bound_instance: BoundInstance
@@ -35,6 +38,8 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,
supports_tensor=False,
),
device_rank=device_rank,
world_size=world_size,
@@ -67,5 +72,27 @@ def get_mlx_ring_instance(
shard_assignments=get_shard_assignments(
model_id, node_to_runner, runner_to_shard
),
hosts=[],
hosts_by_node={},
ephemeral_port=50000,
)


def get_bound_mlx_ring_instance(
instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
) -> BoundInstance:
shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)
other_shard = get_pipeline_shard_metadata(
model_id=model_id, device_rank=1, world_size=2
)
instance = get_mlx_ring_instance(
instance_id=instance_id,
model_id=model_id,
node_to_runner={
node_id: runner_id,
NodeId("other_node"): RunnerId("other_runner"),
},
runner_to_shard={runner_id: shard, RunnerId("other_runner"): other_shard},
)
return BoundInstance(
instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
)

@@ -4,7 +4,8 @@ from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerWaitingForModel,
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
@@ -38,13 +39,11 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())

runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
all_runners = {RUNNER_1_ID: RunnerIdle()}

# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
@@ -82,15 +81,15 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
bound_instance=bound_instance, status=RunnerConnected()
)

runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}

all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}

# Local node has already marked its shard as downloaded (not actually used by _load_model)
@@ -133,13 +132,11 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())

runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
all_runners = {RUNNER_1_ID: RunnerIdle()}

# Local status claims the shard is downloaded already
local_download_status = {
@@ -183,14 +180,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
bound_instance=bound_instance, status=RunnerConnected()
)

runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}

# Only NODE_A's shard is recorded as downloaded globally
@@ -213,3 +210,22 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
)

assert result is None

global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
],  # NODE_B has no downloads completed yet
}

result = plan_mod.plan(
node_id=NODE_A,
runners=runners,  # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)

assert result is not None

@@ -5,9 +5,9 @@ from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerReady,
RunnerRunning,
RunnerWaitingForModel,
)
from exo.worker.tests.constants import (
COMMAND_1_ID,
@@ -99,7 +99,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerIdle(),
}

task = ChatCompletion(

@@ -2,8 +2,9 @@ import exo.worker.plan as plan_mod
from exo.shared.types.tasks import StartWarmup
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoaded,
RunnerWaitingForModel,
RunnerLoading,
RunnerWarmingUp,
)
from exo.worker.tests.constants import (
@@ -21,9 +22,9 @@ from exo.worker.tests.unittests.conftest import (
)


def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-zero device_rank shards, StartWarmup should be emitted when all
For non-final device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
@@ -36,13 +37,13 @@
)

bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)

runners = {RUNNER_2_ID: local_runner}
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
@@ -50,10 +51,10 @@
}

result = plan_mod.plan(
node_id=NODE_B,
node_id=NODE_A,
runners=runners,  # type: ignore
download_status={},
global_download_status={NODE_A: []},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
@@ -128,7 +129,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerIdle(),
RUNNER_2_ID: RunnerLoaded(),
}

@@ -149,6 +150,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -159,6 +163,7 @@
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)

# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -173,6 +178,93 @@
RUNNER_2_ID: RunnerLoaded(),
}

result = plan_mod.plan(
node_id=NODE_A,
runners=runners,  # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)

assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID


def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
"""
For connecting rank (device_rank == world_size - 1), StartWarmup should
only be emitted once all the other runners are already warming up.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)

# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)

runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWarmingUp(),
RUNNER_2_ID: RunnerLoaded(),
}

result = plan_mod.plan(
node_id=NODE_B,
runners=runners,  # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)

assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID


def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
"""
Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)

# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)

runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoading(),
}

result = plan_mod.plan(
node_id=NODE_A,
runners=runners,  # type: ignore
@@ -184,3 +276,46 @@
)

assert result is None


def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == world_size - 1) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)

# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)

runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}

result = plan_mod.plan(
node_id=NODE_B,
runners=runners,  # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)

assert result is None

@@ -0,0 +1,208 @@
# Check tasks are complete before runner is ever ready.
from collections.abc import Iterable
from typing import Callable

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel

from ...constants import (
CHAT_COMPLETION_TASK_ID,
COMMAND_1_ID,
INITIALIZATION_TASK_ID,
INSTANCE_1_ID,
LOAD_TASK_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
SHUTDOWN_TASK_ID,
WARMUP_TASK_ID,
)
from ..conftest import get_bound_mlx_ring_instance


def make_nothin[T, U, V](res: T) -> Callable[[], T]:
def nothin(*_1: U, **_2: V) -> T:
return res

return nothin


nothin = make_nothin(None)


INIT_TASK = ConnectToGroup(
task_id=INITIALIZATION_TASK_ID,
instance_id=INSTANCE_1_ID,
)

LOAD_TASK = LoadModel(
task_id=LOAD_TASK_ID,
instance_id=INSTANCE_1_ID,
)

WARMUP_TASK = StartWarmup(
task_id=WARMUP_TASK_ID,
instance_id=INSTANCE_1_ID,
)

SHUTDOWN_TASK = Shutdown(
task_id=SHUTDOWN_TASK_ID,
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)

CHAT_PARAMS = ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=4,
temperature=0.0,
)

CHAT_TASK = ChatCompletion(
task_id=CHAT_COMPLETION_TASK_ID,
command_id=COMMAND_1_ID,
task_params=CHAT_PARAMS,
instance_id=INSTANCE_1_ID,
)


def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
for test_event, true_event in zip(test_events, true_events, strict=True):
test_event.event_id = true_event.event_id
assert test_event == true_event, f"{test_event} != {true_event}"


@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)

def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")

monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)


def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)

task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()

with task_sender, event_receiver:
for t in tasks:
task_sender.send(t)

# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin

mlx_runner.main(bound_instance, event_sender, task_receiver)

return event_receiver.collect()


def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])

expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,
finish_reason="stop",
),
)

assert_events_equal(
events,
[
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
),
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)
@@ -0,0 +1 @@
# TODO:
6 src/exo/worker/utils/__init__.py Normal file
@@ -0,0 +1,6 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics

__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]
76 src/exo/worker/utils/net_profile.py Normal file
@@ -0,0 +1,76 @@
import http.client

from anyio import create_task_group, to_thread
from loguru import logger

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId


async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""

def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None

body = response.read().decode("utf-8").strip()

# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]

return NodeId(body) or None
except OSError:
return None
finally:
connection.close()

remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return

if remote_node_id == self_node_id:
return

if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return

if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)


async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
)

return reachable
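A hypothetical usage sketch for the new helper (assumed, not part of the diff): probe the current topology once and print which node IDs answered on which IPs. The import paths follow the new file above.

```python
import anyio

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.worker.utils.net_profile import check_reachable


async def probe_once(topology: Topology, self_node_id: NodeId) -> None:
    # One round of the same probe the Worker loop runs every 10 seconds.
    reachable = await check_reachable(topology, self_node_id)
    for node_id, ips in reachable.items():
        print(node_id, sorted(ips))

# e.g. anyio.run(probe_once, topology, my_node_id)
```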
114 src/exo/worker/utils/profile.py Normal file
@@ -0,0 +1,114 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine

import anyio
from loguru import logger

from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)

from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)


async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""

if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()


def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)

return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)


async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.

Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)


async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return

network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()

# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()

await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)

except asyncio.TimeoutError:
logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)
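A hypothetical wiring sketch (assumed, not from the diff): run the memory poller inside a task group with a callback that just logs each reading; cancelling the group stops the loop, since the poller never returns on its own.

```python
import anyio
from loguru import logger

from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.worker.utils import start_polling_memory_metrics


async def log_memory(profile: MemoryPerformanceProfile) -> None:
    logger.info(f"memory reading: {profile}")


async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(start_polling_memory_metrics, log_memory)
        await anyio.sleep(5)      # observe a few 0.5 s ticks...
        tg.cancel_scope.cancel()  # ...then stop the poller

# anyio.run(main)
```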
41 uv.lock generated
@@ -334,8 +334,10 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -374,9 +376,11 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", specifier = ">=0.30.1" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.30.1" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.30.1" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "protobuf", specifier = ">=6.32.0" },
{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pydantic", specifier = ">=2.11.7" },
@@ -801,6 +805,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]

[package.optional-dependencies]
cpu = [
{ name = "mlx-cpu", marker = "sys_platform == 'linux'" },
]

[[package]]
name = "mlx-cpu"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/51/32903727a68a61e972383e28a775c1f5e5f0628552c85cbc6103d68c0dc4/mlx_cpu-0.30.1-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:3f5dc2e4d0849181f8253508bb6a0854250483fc63d43ac79ec614b19824b172", size = 8992394, upload-time = "2025-12-18T00:16:13.696Z" },
{ url = "https://files.pythonhosted.org/packages/0c/74/69c21bb907f3c4064881ab0653029c939ae15fc4e63a5301ef8643cb1d68/mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:c9ea6992d8c001e1123dfd3b4d4405ff576c787eec52656ad405e3d033a8be60", size = 10553055, upload-time = "2025-12-18T00:16:16.104Z" },
]

[[package]]
name = "mlx-lm"
version = "0.28.3"
@@ -946,6 +964,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" },
]

[[package]]
name = "openai-harmony"
version = "0.0.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3e/92/2d038d096f29179c7c9571b431f9e739f87a487121901725e23fe338dd9d/openai_harmony-0.0.8.tar.gz", hash = "sha256:6e43f98e6c242fa2de6f8ea12eab24af63fa2ed3e89c06341fb9d92632c5cbdf", size = 284777, upload-time = "2025-11-05T19:07:06.727Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/c6/2502f416d46be3ec08bb66d696cccffb57781a499e3ff2e4d7c174af4e8f/openai_harmony-0.0.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:029ec25ca74abe48fdb58eb9fdd2a8c1618581fc33ce8e5653f8a1ffbfbd9326", size = 2627806, upload-time = "2025-11-05T19:06:57.063Z" },
{ url = "https://files.pythonhosted.org/packages/d3/d2/ce6953ca87db9cae3e775024184da7d1c5cb88cead19a2d75b42f00a959c/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4f709815924ec325b9a890e6ab2bbb0ceec8e319a4e257328eb752cf36b2efc", size = 2948463, upload-time = "2025-11-05T19:06:48.17Z" },
{ url = "https://files.pythonhosted.org/packages/fa/4c/b553c9651662d6ce102ca7f3629d268b23df1abe5841e24bed81e8a8e949/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cfcfd963b50a41fc656c84d3440ca6eecdccd6c552158ce790b8f2e33dfb5a9", size = 2704083, upload-time = "2025-11-05T19:06:50.205Z" },
{ url = "https://files.pythonhosted.org/packages/9b/af/4eec8f9ab9c27bcdb444460c72cf43011d176fc44c79d6e113094ca1e152/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a3a16972aa1cee38ea958470cd04ac9a2d5ac38fdcf77ab686611246220c158", size = 2959765, upload-time = "2025-11-05T19:06:53.62Z" },
{ url = "https://files.pythonhosted.org/packages/11/3c/33f3374e4624e0e776f6b13b73c45a7ead7f9c4529f8369ed5bfcaa30cac/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4d5cfa168e74d08f8ba6d58a7e49bc7daef4d58951ec69b66b0d56f4927a68d", size = 3427031, upload-time = "2025-11-05T19:06:51.829Z" },
{ url = "https://files.pythonhosted.org/packages/25/3f/1a192b93bb47c6b44cd98ba8cc1d3d2a9308f1bb700c3017e6352da11bda/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c007d277218a50db8839e599ed78e0fffe5130f614c3f6d93ae257f282071a29", size = 2953260, upload-time = "2025-11-05T19:06:55.406Z" },
{ url = "https://files.pythonhosted.org/packages/5b/f8/93b582cad3531797c3db7c2db5400fd841538ccddfd9f5e3df61be99a630/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8565d4f5a0638da1bffde29832ed63c9e695c558611053add3b2dc0b56c92dbc", size = 3127044, upload-time = "2025-11-05T19:06:59.553Z" },
{ url = "https://files.pythonhosted.org/packages/1d/10/4327dbf87f75ae813405fd9a9b4a5cde63d506ffed0a096a440a4cabd89c/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cbaa3bda75ef0d8836e1f8cc84af62f971b1d756d740efc95c38c3e04c0bfde2", size = 2932931, upload-time = "2025-11-05T19:07:01.437Z" },
{ url = "https://files.pythonhosted.org/packages/8a/c8/1774eec4f6f360ef57618fb8f52e3d3af245b2491bd0297513aa09eec04b/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:772922a9bd24e133950fad71eb1550836f415a88e8c77870e12d0c3bd688ddc2", size = 2996140, upload-time = "2025-11-05T19:07:03.438Z" },
{ url = "https://files.pythonhosted.org/packages/60/c3/3d1e01e2dba517a91760e4a03e4f20ffc75039a6fe584d0e6f9b5c78fd15/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:007b0476a1f331f8130783f901f1da6f5a7057af1a4891f1b6a31dec364189b5", size = 3205080, upload-time = "2025-11-05T19:07:05.078Z" },
]

[[package]]
name = "packaging"
version = "25.0"