Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-30 09:40:46 -05:00)

Compare commits: gather-lin...optimize-d (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 40cbecb5c4 | |
| | fea42473dd | |
| | ca7adcc2a8 | |
| | 9d9e24f969 | |
| | b5d424b658 | |
| | b465134012 | |
| | eabdcab978 | |
| | 8e9332d6a7 | |
| | 4b65d5f896 | |
.gitignore (vendored): 2 changed lines

@@ -7,6 +7,8 @@ digest.txt
# nix
.direnv/

+# IDEA (PyCharm)
+.idea

# xcode / macos
*.xcuserstate
README.md: 83 changed lines

@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua

There are two ways to run exo:

-### Run from Source (Mac & Linux)
+### Run from Source (macOS)

**Prerequisites:**
-- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
+- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)

```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"

@@ -98,6 +98,62 @@ uv run exo

This starts the exo dashboard and API at http://localhost:52415/

### Run from Source (Linux)

**Prerequisites:**

- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [node](https://github.com/nodejs/node) (for building the dashboard), version 18 or higher
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings; nightly for now)

**Installation methods:**

**Option 1: Using the system package manager (Ubuntu/Debian example):**
```bash
# Install Node.js and npm
sudo apt update
sudo apt install nodejs npm

# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```

**Option 2: Using Homebrew on Linux (if preferred):**
```bash
# Install Homebrew on Linux
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"

# Install dependencies
brew install uv node

# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```

**Note:** The `macmon` package is macOS-only and not required on Linux.

Clone the repo, build the dashboard, and run exo:

```bash
# Clone exo
git clone https://github.com/exo-explore/exo

# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..

# Run exo
uv run exo
```

This starts the exo dashboard and API at http://localhost:52415/

**Important note for Linux users:** Currently, exo runs on the CPU on Linux. GPU support for Linux is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.

### macOS App

exo ships a macOS app that runs in the background on your Mac.

@@ -112,6 +168,29 @@ The app will ask for permission to modify system settings and install a new Netw

---

### Enabling RDMA on macOS

RDMA is a new capability added in macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).

Note that on the Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.

To enable RDMA on macOS, follow these steps:

1. Shut down your Mac.
2. Hold down the power button for 10 seconds until the boot menu appears.
3. Select "Options" to enter Recovery mode.
4. When the Recovery UI appears, open the Terminal from the Utilities menu.
5. In the Terminal, type:
   ```
   rdma_ctl enable
   ```
   and press Enter.
6. Reboot your Mac.

After that, RDMA will be enabled in macOS and exo will take care of the rest.

---

### Using the API

If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request, and deleting the instance.
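(The full create/request/delete example continues in the README; this diff view truncates it. For orientation, here is a hedged sketch of the middle step only. The chat-completions path below is an assumption based on the OpenAI-compatible convention and the dashboard port named above; the instance create/delete endpoints are not shown in this diff, so they are omitted rather than guessed.)

```bash
# Sketch, not the README's verbatim example: send one chat completion
# to a running exo node. Assumes the OpenAI-compatible route on the
# dashboard port from the README.
curl -s http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
        "messages": [{"role": "user", "content": "Hello!"}]
      }'
```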
@@ -49,7 +49,7 @@ struct ContentView: View {

private var topologySection: some View {
Group {
-if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
+if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
TopologyMiniView(topology: topology)
}
}

@@ -82,7 +82,6 @@ struct BugReportService {
}

private func loadCredentials() throws -> AWSConfig {
// These credentials are write-only and necessary to receive bug reports from users
return AWSConfig(
accessKey: "AKIAYEKP5EMXTOBYDGHX",
secretKey: "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",

@@ -7,6 +7,7 @@ final class ClusterStateService: ObservableObject {
@Published private(set) var lastError: String?
@Published private(set) var lastActionMessage: String?
@Published private(set) var modelOptions: [ModelOption] = []
+@Published private(set) var localNodeId: String?

private var timer: Timer?
private let decoder: JSONDecoder
@@ -29,6 +30,7 @@ final class ClusterStateService: ObservableObject {
func startPolling(interval: TimeInterval = 0.5) {
stopPolling()
Task {
+await fetchLocalNodeId()
await fetchModels()
await fetchSnapshot()
}
@@ -46,9 +48,31 @@ final class ClusterStateService: ObservableObject {
latestSnapshot = nil
lastError = nil
lastActionMessage = nil
+localNodeId = nil
}

private func fetchLocalNodeId() async {
do {
let url = baseURL.appendingPathComponent("node_id")
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
return
}
if let nodeId = try? decoder.decode(String.self, from: data) {
localNodeId = nodeId
}
} catch {
// Silently ignore - localNodeId will remain nil and retry on next poll
}
}

private func fetchSnapshot() async {
+// Retry fetching local node ID if not yet set
+if localNodeId == nil {
+await fetchLocalNodeId()
+}
do {
var request = URLRequest(url: endpoint)
request.cachePolicy = .reloadIgnoringLocalCacheData

@@ -85,7 +85,7 @@ struct TopologyViewModel {
}

extension ClusterState {
-func topologyViewModel() -> TopologyViewModel? {
+func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
guard !allNodes.isEmpty else { return nil }
@@ -105,6 +105,11 @@ extension ClusterState {
orderedNodes = allNodes
}

+// Rotate so the local node (from /node_id API) is first
+if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
+orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
+}

let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
@@ -112,10 +117,7 @@
} ?? []
let edges = Set(edgesArray)

-let topologyRootId = topology?.nodes.first?.nodeId
-let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id

-return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
+return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
}
}
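(A side note on the `/node_id` endpoint polled above: judging from `decoder.decode(String.self, from: data)`, the response is presumably a single JSON-encoded string. A quick manual check against a running node, using the dashboard port from the README; the response shape is inferred from the Swift decoder, not documented in this diff:)

```bash
# Inspect the node id the macOS app uses to rotate the topology view.
curl -s http://localhost:52415/node_id
# expected (inferred): a JSON string such as "a1b2c3d4-..."
```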
dashboard/package-lock.json (generated): 48 changed lines

@@ -9,6 +9,8 @@
"version": "1.0.0",
"dependencies": {
"highlight.js": "^11.11.1",
+"katex": "^0.16.27",
+"marked": "^17.0.1",
"mode-watcher": "^1.1.0"
},
"devDependencies": {
@@ -861,7 +863,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
-"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -901,7 +902,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
-"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,7 +1518,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
-"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1528,7 +1527,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
-"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,7 +1939,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
-"peer": true,
"engines": {
"node": ">=12"
}
@@ -2254,6 +2251,31 @@
"jiti": "lib/jiti-cli.mjs"
}
},
"node_modules/katex": {
"version": "0.16.27",
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
"integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
"funding": [
"https://opencollective.com/katex",
"https://github.com/sponsors/katex"
],
"license": "MIT",
"dependencies": {
"commander": "^8.3.0"
},
"bin": {
"katex": "cli.js"
}
},
"node_modules/katex/node_modules/commander": {
"version": "8.3.0",
"resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
"integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
"license": "MIT",
"engines": {
"node": ">= 12"
}
},
"node_modules/kleur": {
"version": "4.1.5",
"resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
@@ -2540,6 +2562,18 @@
"@jridgewell/sourcemap-codec": "^1.5.5"
}
},
"node_modules/marked": {
"version": "17.0.1",
"resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
"integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
"license": "MIT",
"bin": {
"marked": "bin/marked.js"
},
"engines": {
"node": ">= 20"
}
},
"node_modules/mode-watcher": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",
@@ -2612,7 +2646,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
-"peer": true,
"engines": {
"node": ">=12"
},
@@ -2800,7 +2833,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
-"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2945,7 +2977,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
-"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2967,7 +2998,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
-"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

@@ -27,7 +27,8 @@
},
"dependencies": {
"highlight.js": "^11.11.1",
+"katex": "^0.16.27",
"marked": "^17.0.1",
"mode-watcher": "^1.1.0"
}
}
@@ -198,8 +198,10 @@
stroke: oklch(0.85 0.18 85 / 0.4);
stroke-width: 1.5px;
stroke-dasharray: 8, 8;
-animation: flowAnimation 1s linear infinite;
+animation: flowAnimation 1.5s linear infinite;
filter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));
+/* GPU optimization - hint to browser this element will animate */
+will-change: stroke-dashoffset;
}

.graph-link-active {
@@ -208,6 +210,24 @@
filter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));
}

/* Reduce motion for users who prefer it - also saves GPU */
@media (prefers-reduced-motion: reduce) {
.graph-link {
animation: none;
}

.shooting-star {
animation: none;
display: none;
}

.status-pulse,
.cursor-blink,
.animate-pulse {
animation: none;
}
}

/* CRT Screen effect for topology */
.crt-screen {
position: relative;
@@ -266,13 +286,15 @@ input:focus, textarea:focus {
box-shadow: none;
}

-/* Shooting Stars Animation */
+/* Shooting Stars Animation - GPU optimized */
.shooting-stars {
position: fixed;
inset: 0;
overflow: hidden;
pointer-events: none;
z-index: 0;
+/* Only render when visible */
+content-visibility: auto;
}

.shooting-star {
@@ -285,6 +307,9 @@ input:focus, textarea:focus {
animation: shootingStar var(--duration, 3s) linear infinite;
animation-delay: var(--delay, 0s);
opacity: 0;
+/* GPU optimization */
+will-change: transform, opacity;
+transform: translateZ(0);
}

.shooting-star::before {
@@ -320,3 +345,13 @@ input:focus, textarea:focus {
transform: translate(400px, 400px);
}
}

/* Pause animations when page is hidden to save resources */
:root:has(body[data-page-hidden="true"]) {
.shooting-star,
.graph-link,
.status-pulse,
.cursor-blink {
animation-play-state: paused;
}
}
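(For context on the last rule: CSS alone cannot observe tab visibility, so `body[data-page-hidden]` is presumably toggled by a small Page Visibility listener elsewhere in the dashboard, which this diff does not show. A minimal sketch of that assumed wiring:)

```typescript
// Hypothetical wiring for the pause rule above: keep <body data-page-hidden>
// in sync with the Page Visibility API so the :has() selector can match.
function trackPageVisibility(): () => void {
  const update = () => {
    // dataset.pageHidden maps to the data-page-hidden attribute.
    document.body.dataset.pageHidden = document.hidden ? 'true' : 'false';
  };
  update(); // initialize on load
  document.addEventListener('visibilitychange', update);
  return () => document.removeEventListener('visibilitychange', update);
}
```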
@@ -8,89 +8,80 @@
regenerateLastResponse
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import { tick, onDestroy } from 'svelte';
+import MarkdownContent from './MarkdownContent.svelte';

interface Props {
class?: string;
scrollParent?: HTMLElement | null;
}

let { class: className = '', scrollParent = null }: Props = $props();

const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());

// Ref for scroll anchor at bottom
let scrollAnchorRef: HTMLDivElement | undefined = $state();
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
let showScrollButton = $state(false);
let lastMessageCount = 0;
let containerRef: HTMLDivElement | undefined = $state();

// Scroll management
const SCROLL_BOTTOM_THRESHOLD = 120;
let autoScrollEnabled = true;
let currentScrollEl: HTMLElement | null = null;

function resolveScrollElement(): HTMLElement | null {
if (scrollParent) return scrollParent;
let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
while (node) {
const isScrollable = node.scrollHeight > node.clientHeight + 1;
if (isScrollable) return node;
node = node.parentElement;
function getScrollContainer(): HTMLElement | null {
if (scrollParent) return scrollParent;
return containerRef?.parentElement ?? null;
}
return null;
}

function handleScroll() {
if (!currentScrollEl) return;
const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
autoScrollEnabled = isNearBottom;
}

function attachScrollListener() {
const nextEl = resolveScrollElement();
if (currentScrollEl === nextEl) return;
if (currentScrollEl) {
currentScrollEl.removeEventListener('scroll', handleScroll);
function isNearBottom(el: HTMLElement): boolean {
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
}
currentScrollEl = nextEl;
if (currentScrollEl) {
currentScrollEl.addEventListener('scroll', handleScroll);
// Initialize state based on current position
handleScroll();
}
}

onDestroy(() => {
if (currentScrollEl) {
currentScrollEl.removeEventListener('scroll', handleScroll);
}
});

$effect(() => {
// Re-evaluate scroll container if prop changes or after mount
scrollParent;
attachScrollListener();
});

// Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
$effect(() => {
// Track these values to trigger effect
const _ = messageList.length;
const __ = response;
const ___ = loading;

tick().then(() => {
const el = currentScrollEl ?? resolveScrollElement();
if (!el || !scrollAnchorRef) return;
const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
if (autoScrollEnabled || isNearBottom) {
scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
autoScrollEnabled = true;
function scrollToBottom() {
const el = getScrollContainer();
if (el) {
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
}
}

function updateScrollButtonVisibility() {
const el = getScrollContainer();
if (!el) return;
showScrollButton = !isNearBottom(el);
}

// Attach scroll listener
$effect(() => {
const el = scrollParent ?? containerRef?.parentElement;
if (!el) return;

el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
// Initial check
updateScrollButtonVisibility();
return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
});

// Auto-scroll when user sends a new message
$effect(() => {
const count = messageList.length;
if (count > lastMessageCount) {
const el = getScrollContainer();
if (el) {
requestAnimationFrame(() => {
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
});
}
}
lastMessageCount = count;
});

// Update scroll button visibility when content changes
$effect(() => {
// Track response to trigger re-check during streaming
const _ = response;

// Small delay to let DOM update
requestAnimationFrame(() => updateScrollButtonVisibility());
});
});

// Edit state
let editingMessageId = $state<string | null>(null);
@@ -231,7 +222,7 @@ function isThinkingExpanded(messageId: string): boolean {
<div class="flex flex-col gap-4 sm:gap-6 {className}">
{#each messageList as message (message.id)}
<div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
-<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
+<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}">
{#if message.role === 'assistant'}
<!-- Assistant message header -->
<div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
@@ -305,7 +296,7 @@ function isThinkingExpanded(messageId: string): boolean {
{:else}
<div class="{message.role === 'user'
? 'command-panel rounded-lg rounded-tr-sm inline-block'
-: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
+: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}">

{#if message.role === 'user'}
<!-- User message styling -->
@@ -331,7 +322,7 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}

{#if message.content}
-<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
+<div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
{message.content}
</div>
{/if}
@@ -360,7 +351,7 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
<span>Thinking...</span>
</span>
-<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
+<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
{isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
</span>
</button>
@@ -374,8 +365,8 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
-<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
-{message.content || (loading ? response : '')}
+<div class="text-xs text-foreground">
+<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -457,6 +448,20 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}

-<!-- Scroll anchor for auto-scroll -->
-<div bind:this={scrollAnchorRef}></div>
+<!-- Invisible element for container reference -->
+<div bind:this={containerRef}></div>

<!-- Scroll to bottom button -->
{#if showScrollButton}
<button
type="button"
onclick={scrollToBottom}
class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
title="Scroll to bottom"
>
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
</svg>
</button>
{/if}
</div>
@@ -10,7 +10,9 @@ import {
clearChat,
instances,
debugMode,
-toggleDebugMode
+toggleDebugMode,
+topologyOnlyMode,
+toggleTopologyOnlyMode
} from '$lib/stores/app.svelte';

interface Props {
@@ -23,6 +25,7 @@ import {
const activeId = $derived(activeConversationId());
const instanceData = $derived(instances());
const debugEnabled = $derived(debugMode());
+const topologyOnlyEnabled = $derived(topologyOnlyMode());

let searchQuery = $state('');
let editingId = $state<string | null>(null);
@@ -424,6 +427,19 @@ const debugEnabled = $derived(debugMode());
<div class="text-xs text-white/60 font-mono tracking-wider text-center">
{conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
</div>
<button
type="button"
onclick={toggleTopologyOnlyMode}
class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
title="Toggle topology only mode"
>
<svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="5" r="2" fill="currentColor" />
<circle cx="5" cy="19" r="2" fill="currentColor" />
<circle cx="19" cy="19" r="2" fill="currentColor" />
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
</svg>
</button>
</div>
</div>
</aside>
@@ -3,6 +3,9 @@

export let showHome = true;
export let onHome: (() => void) | null = null;
+export let showSidebarToggle = false;
+export let sidebarVisible = true;
+export let onToggleSidebar: (() => void) | null = null;

function handleHome(): void {
if (onHome) {
@@ -14,13 +17,38 @@
window.location.hash = '/';
}
}

function handleToggleSidebar(): void {
if (onToggleSidebar) {
onToggleSidebar();
}
}
</script>

<header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
<!-- Left: Sidebar Toggle -->
{#if showSidebarToggle}
<div class="absolute left-6 top-1/2 -translate-y-1/2">
<button
onclick={handleToggleSidebar}
class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
>
<svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
{#if sidebarVisible}
<path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
{:else}
<path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
{/if}
</svg>
</button>
</div>
{/if}

<!-- Center: Logo (clickable to go home) -->
<button
onclick={handleHome}
-class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
+class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
title={showHome ? 'Go to home' : ''}
disabled={!showHome}
>
dashboard/src/lib/components/MarkdownContent.svelte (new file): 451 lines

@@ -0,0 +1,451 @@
<script lang="ts">
import { marked } from 'marked';
import hljs from 'highlight.js';
import katex from 'katex';
import 'katex/dist/katex.min.css';
import { browser } from '$app/environment';

interface Props {
content: string;
class?: string;
}

let { content, class: className = '' }: Props = $props();

let containerRef = $state<HTMLDivElement>();
let processedHtml = $state('');

// Configure marked with syntax highlighting
marked.setOptions({
gfm: true,
breaks: true
});

// Custom renderer for code blocks
const renderer = new marked.Renderer();

renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
const highlighted = hljs.highlight(text, { language }).value;
const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;

return `
<div class="code-block-wrapper">
<div class="code-block-header">
<span class="code-language">${language}</span>
<button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
</div>
`;
};

// Inline code
renderer.codespan = function ({ text }: { text: string }) {
return `<code class="inline-code">${text}</code>`;
};

marked.use({ renderer });

/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
*/
function preprocessLaTeX(text: string): string {
// Protect code blocks
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `<<CODE_${codeBlocks.length - 1}>>`;
});

// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');

// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');

// Restore code blocks
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);

return processed;
}

/**
* Render math expressions with KaTeX after HTML is generated
*/
function renderMath(html: string): string {
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});

// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});

return html;
}

function processMarkdown(text: string): string {
try {
// Preprocess LaTeX notation
const preprocessed = preprocessLaTeX(text);
// Parse markdown
let html = marked.parse(preprocessed) as string;
// Render math expressions
html = renderMath(html);
return html;
} catch (error) {
console.error('Markdown processing error:', error);
return text.replace(/\n/g, '<br>');
}
}

async function handleCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedCode = target.getAttribute('data-code');
if (!encodedCode) return;

const code = decodeURIComponent(encodedCode);

try {
await navigator.clipboard.writeText(code);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy:', error);
}
}

function setupCopyButtons() {
if (!containerRef || !browser) return;

const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
}

$effect(() => {
if (content) {
processedHtml = processMarkdown(content);
} else {
processedHtml = '';
}
});

$effect(() => {
if (containerRef && processedHtml) {
setupCopyButtons();
}
});
</script>

<div bind:this={containerRef} class="markdown-content {className}">
{@html processedHtml}
</div>

<style>
.markdown-content {
line-height: 1.6;
}

/* Paragraphs */
.markdown-content :global(p) {
margin-bottom: 1rem;
}

.markdown-content :global(p:last-child) {
margin-bottom: 0;
}

/* Headers */
.markdown-content :global(h1) {
font-size: 1.5rem;
font-weight: 700;
margin: 1.5rem 0 0.75rem 0;
color: var(--exo-yellow, #ffd700);
}

.markdown-content :global(h2) {
font-size: 1.25rem;
font-weight: 600;
margin: 1.25rem 0 0.5rem 0;
color: var(--exo-yellow, #ffd700);
}

.markdown-content :global(h3) {
font-size: 1.125rem;
font-weight: 600;
margin: 1rem 0 0.5rem 0;
}

.markdown-content :global(h4),
.markdown-content :global(h5),
.markdown-content :global(h6) {
font-size: 1rem;
font-weight: 600;
margin: 0.75rem 0 0.25rem 0;
}

/* Bold and italic */
.markdown-content :global(strong) {
font-weight: 600;
}

.markdown-content :global(em) {
font-style: italic;
}

/* Inline code */
.markdown-content :global(.inline-code) {
background: rgba(255, 215, 0, 0.1);
color: var(--exo-yellow, #ffd700);
padding: 0.125rem 0.375rem;
border-radius: 0.25rem;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
}

/* Links */
.markdown-content :global(a) {
color: var(--exo-yellow, #ffd700);
text-decoration: underline;
text-underline-offset: 2px;
}

.markdown-content :global(a:hover) {
opacity: 0.8;
}

/* Lists */
.markdown-content :global(ul) {
list-style-type: disc;
margin-left: 1.5rem;
margin-bottom: 1rem;
}

.markdown-content :global(ol) {
list-style-type: decimal;
margin-left: 1.5rem;
margin-bottom: 1rem;
}

.markdown-content :global(li) {
margin-bottom: 0.25rem;
}

.markdown-content :global(li::marker) {
color: var(--exo-light-gray, #9ca3af);
}

/* Blockquotes */
.markdown-content :global(blockquote) {
border-left: 3px solid var(--exo-yellow, #ffd700);
padding: 0.5rem 1rem;
margin: 1rem 0;
background: rgba(255, 215, 0, 0.05);
border-radius: 0 0.25rem 0.25rem 0;
}

/* Tables */
.markdown-content :global(table) {
width: 100%;
margin: 1rem 0;
border-collapse: collapse;
font-size: 0.875rem;
}

.markdown-content :global(th) {
background: rgba(255, 215, 0, 0.1);
border: 1px solid rgba(255, 215, 0, 0.2);
padding: 0.5rem;
text-align: left;
font-weight: 600;
}

.markdown-content :global(td) {
border: 1px solid rgba(255, 255, 255, 0.1);
padding: 0.5rem;
}

/* Horizontal rule */
.markdown-content :global(hr) {
border: none;
border-top: 1px solid rgba(255, 255, 255, 0.1);
margin: 1.5rem 0;
}

/* Code block wrapper */
.markdown-content :global(.code-block-wrapper) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.2);
background: rgba(0, 0, 0, 0.4);
}

.markdown-content :global(.code-block-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.5rem 0.75rem;
background: rgba(255, 215, 0, 0.05);
border-bottom: 1px solid rgba(255, 215, 0, 0.1);
}

.markdown-content :global(.code-language) {
color: var(--exo-yellow, #ffd700);
font-size: 0.7rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}

.markdown-content :global(.copy-code-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
}

.markdown-content :global(.copy-code-btn:hover) {
color: var(--exo-yellow, #ffd700);
}

.markdown-content :global(.copy-code-btn.copied) {
color: #22c55e;
}

.markdown-content :global(.code-block-wrapper pre) {
margin: 0;
padding: 1rem;
overflow-x: auto;
background: transparent;
}

.markdown-content :global(.code-block-wrapper code) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.8125rem;
line-height: 1.5;
background: transparent;
}

/* Syntax highlighting - dark theme matching EXO style */
.markdown-content :global(.hljs) {
color: #e5e7eb;
}

.markdown-content :global(.hljs-keyword),
.markdown-content :global(.hljs-selector-tag),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-section),
.markdown-content :global(.hljs-link) {
color: #c084fc;
}

.markdown-content :global(.hljs-string),
.markdown-content :global(.hljs-title),
.markdown-content :global(.hljs-name),
.markdown-content :global(.hljs-type),
.markdown-content :global(.hljs-attribute),
.markdown-content :global(.hljs-symbol),
.markdown-content :global(.hljs-bullet),
.markdown-content :global(.hljs-addition),
.markdown-content :global(.hljs-variable),
.markdown-content :global(.hljs-template-tag),
.markdown-content :global(.hljs-template-variable) {
color: #fbbf24;
}

.markdown-content :global(.hljs-comment),
.markdown-content :global(.hljs-quote),
.markdown-content :global(.hljs-deletion),
.markdown-content :global(.hljs-meta) {
color: #6b7280;
}

.markdown-content :global(.hljs-number),
.markdown-content :global(.hljs-regexp),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-built_in) {
color: #34d399;
}

.markdown-content :global(.hljs-function),
.markdown-content :global(.hljs-class .hljs-title) {
color: #60a5fa;
}

/* KaTeX math styling */
.markdown-content :global(.katex) {
font-size: 1.1em;
}

.markdown-content :global(.katex-display) {
margin: 1rem 0;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}

.markdown-content :global(.katex-display > .katex) {
text-align: center;
}

.markdown-content :global(.math-error) {
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
}
</style>
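(One easy-to-misread detail in `preprocessLaTeX` above: in JavaScript replacement strings, `$$` is an escaped literal `$` and `$1` is the first capture group, so `'$$$1$'` wraps the captured math in single dollar signs, and `'$$$$$1$$$$'` in double ones. A quick check:)

```typescript
// '$$' -> literal '$', '$1' -> capture group, trailing '$' -> literal '$'
const inline = 'Euler: \\(e^{i\\pi} + 1 = 0\\)';
console.log(inline.replace(/\\\((.+?)\\\)/g, '$$$1$'));
// -> "Euler: $e^{i\pi} + 1 = 0$"
```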
@@ -1,5 +1,6 @@
<script lang="ts">
-import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
+import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
+import { debugMode, topologyData } from '$lib/stores/app.svelte';

interface Props {
model: { id: string; name?: string; storage_size_megabytes?: number };
@@ -206,12 +207,8 @@ function toggleNodeDetails(nodeId: string): void {
const centerY = topoHeight / 2;
const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;

-// Use API preview data if available
+// Only use API preview data - no local estimation
const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
-const canFit = hasApiPreview ? true : (() => {
-const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
-return totalAvailable >= estimatedMemory;
-})();
const error = apiPreview?.error ?? null;

let placementNodes: Array<{
@@ -232,135 +229,140 @@ function toggleNodeDetails(nodeId: string): void {
modelFillHeight: number;
}> = [];

if (hasApiPreview && apiPreview.memory_delta_by_node) {
// Use API placement data
const memoryDelta = apiPreview.memory_delta_by_node;
placementNodes = nodeArray.map((n, i) => {
const deltaBytes = memoryDelta[n.id] ?? 0;
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
const isUsed = deltaBytes > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;

return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
} else if (apiPreview?.error) {
// API returned an error - model can't fit, show all nodes as unused
placementNodes = nodeArray.map((n, i) => {
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const screenHeight = iconSize * 0.58;

return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: 0,
currentPercent,
newPercent: currentPercent,
isUsed: false,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: 0
};
});
} else {
// Fallback: local estimation based on sharding strategy
const memoryNeeded = estimatedMemory;
// Use API placement data directly
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
placementNodes = nodeArray.map((n, i) => {
const deltaBytes = memoryDelta[n.id] ?? 0;
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
const isUsed = deltaBytes > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;

if (sharding === 'Pipeline') {
const memoryPerNode = memoryNeeded / numNodes;
placementNodes = nodeArray.map((n, i) => {
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;

return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: memoryPerNode,
currentPercent,
newPercent,
isUsed: true,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
} else {
let remaining = memoryNeeded;
placementNodes = nodeArray.map((n, i) => {
const allocated = Math.min(remaining, n.availableGB);
remaining -= allocated;
const isUsed = allocated > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;

return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: allocated,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
}
}
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});

const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
-return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
+return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
});

const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
const placementError = $derived(apiPreview?.error ?? null);
const nodeCount = $derived(nodeList().length);
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));

// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');

// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
if (!ip || !topology?.nodes) return null;

// Strip port if present
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

// Check specified node first
const node = topology.nodes[nodeId];
if (node) {
const match = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (match?.name) return match.name;

const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped) return mapped;
}

// Fallback: check all nodes
for (const [, otherNode] of Object.entries(topology.nodes)) {
if (!otherNode) continue;
const match = otherNode.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (match?.name) return match.name;

const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
if (mapped) return mapped;
}

return null;
}

// Get directional arrow based on node positions
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
const dx = toNode.x - fromNode.x;
const dy = toNode.y - fromNode.y;
const absX = Math.abs(dx);
const absY = Math.abs(dy);

if (absX > absY * 2) {
return dx > 0 ? '→' : '←';
} else if (absY > absX * 2) {
return dy > 0 ? '↓' : '↑';
} else {
if (dx > 0 && dy > 0) return '↘';
if (dx > 0 && dy < 0) return '↗';
if (dx < 0 && dy > 0) return '↙';
return '↖';
}
}

// Get connection info for edges between two nodes
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
if (!topology?.edges) return [];

// Collect candidates for each direction
const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];

for (const edge of topology.edges) {
const ip = edge.sendBackIp || '?';
const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);

if (edge.source === nodeId1 && edge.target === nodeId2) {
aToBCandidates.push({ ip, iface });
} else if (edge.source === nodeId2 && edge.target === nodeId1) {
bToACandidates.push({ ip, iface });
}
}

// Pick best (prefer non-loopback)
const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
if (candidates.length === 0) return null;
return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
};

const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];

const bestAtoB = pickBest(aToBCandidates);
if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });

const bestBtoA = pickBest(bToACandidates);
if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });

return result;
}
</script>

<div class="relative group">
@@ -453,6 +455,26 @@ function toggleNodeDetails(nodeId: string): void {

<!-- Connection lines between nodes (if multiple) -->
{#if preview.nodes.length > 1}
{@const usedNodes = preview.nodes.filter(n => n.isUsed)}
{@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
{@const allConnections = isDebugMode && usedNodes.length > 1 ? (() => {
const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
for (let i = 0; i < usedNodes.length; i++) {
for (let j = i + 1; j < usedNodes.length; j++) {
const n1 = usedNodes[i];
const n2 = usedNodes[j];
const midX = (n1.x + n2.x) / 2;
const midY = (n1.y + n2.y) / 2;
for (const c of getConnectionInfo(n1.id, n2.id)) {
const fromPos = nodePositions[c.from];
const toPos = nodePositions[c.to];
const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
conns.push({ ...c, midX, midY, arrow });
}
}
}
return conns;
})() : []}
{#each preview.nodes as node, i}
{#each preview.nodes.slice(i + 1) as node2}
<line
@@ -464,6 +486,43 @@ function toggleNodeDetails(nodeId: string): void {
/>
{/each}
{/each}
<!-- Debug: Show connection IPs/interfaces in corners -->
{#if isDebugMode && allConnections.length > 0}
{@const centerX = preview.topoWidth / 2}
{@const centerY = preview.topoHeight / 2}
{@const quadrants = {
topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
}}
{@const padding = 4}
{@const lineHeight = 8}
<!-- Top Left -->
{#each quadrants.topLeft as conn, idx}
<text x={padding} y={padding + idx * lineHeight} text-anchor="start" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Top Right -->
{#each quadrants.topRight as conn, idx}
<text x={preview.topoWidth - padding} y={padding + idx * lineHeight} text-anchor="end" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Bottom Left -->
{#each quadrants.bottomLeft as conn, idx}
<text x={padding} y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight} text-anchor="start" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Bottom Right -->
{#each quadrants.bottomRight as conn, idx}
<text x={preview.topoWidth - padding} y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight} text-anchor="end" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
{/if}
{/if}

{#each preview.nodes as node}
@@ -1,5 +1,5 @@
<script lang="ts">
import { onMount, onDestroy, tick } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';

@@ -12,11 +12,35 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv

let svgContainer: SVGSVGElement | undefined = $state();
let resizeObserver: ResizeObserver | undefined;

// Optimization: Track last render state to avoid unnecessary re-renders
let lastRenderHash = '';
let lastHighlightedNodesHash = '';
let lastDimensions = { width: 0, height: 0 };
let isRendering = false;
let pendingRender = false;

const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());

// Generate a hash of relevant data to detect actual changes
function generateDataHash(topologyData: typeof data, minimized: boolean, debug: boolean): string {
  if (!topologyData) return 'null';
  const nodes = topologyData.nodes || {};
  const edges = topologyData.edges || [];

  // Create a lightweight hash from key properties only
  const nodeHashes = Object.entries(nodes).map(([id, n]) => {
    const macmon = n.macmon_info;
    return `${id}:${n.friendly_name || ''}:${macmon?.memory?.ram_usage || 0}:${macmon?.memory?.ram_total || 0}:${macmon?.temp?.gpu_temp_avg || 0}:${macmon?.gpu_usage?.[1] || 0}:${macmon?.sys_power || 0}`;
  }).sort().join('|');

  const edgeHash = edges.map(e => `${e.source}-${e.target}`).sort().join(',');

  return `${nodeHashes}::${edgeHash}::${minimized}::${debug}`;
}

function getNodeLabel(nodeId: string): string {
  const node = data?.nodes?.[nodeId];
  return node?.friendly_name || nodeId.slice(0, 8);
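generateDataHash above exists so the component can compare one cheap string instead of deep-diffing topology objects on every poll. The same idea, reduced to a standalone sketch (names here are illustrative, not exo's API):

```typescript
// Hedged sketch: skip expensive work when a cheap fingerprint is unchanged.
let lastKey = '';

function renderIfChanged(key: string, render: () => void): void {
  if (key === lastKey) return; // nothing relevant changed; skip the D3 work
  lastKey = key;
  render();
}

// The key should include only the fields that affect the drawing.
renderIfChanged('node-a|node-b::minimized=false', () => console.log('render'));
renderIfChanged('node-a|node-b::minimized=false', () => console.log('render')); // skipped
```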
@@ -24,19 +48,36 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
  if (!ip) return { label: '?', missing: true };

  // Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
  const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

  // Helper to check a node's interfaces
  function checkNode(node: typeof data.nodes[string]): string | null {
    if (!node) return null;

    const matchFromInterfaces = node.network_interfaces?.find((iface) =>
      (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
    );
    if (matchFromInterfaces?.name) {
      return matchFromInterfaces.name;
    }

    const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
    if (mapped && mapped.trim().length > 0) {
      return mapped;
    }
    return null;
  }

  // Try specified node first
  const result = checkNode(data?.nodes?.[nodeId]);
  if (result) return { label: result, missing: false };

  // Fallback: search all nodes for this IP
  for (const [, otherNode] of Object.entries(data?.nodes || {})) {
    const otherResult = checkNode(otherNode);
    if (otherResult) return { label: otherResult, missing: false };
  }

  return { label: '?', missing: true };
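The port-stripping guard above only splits on ':' when the string contains no '[', so a bracketed IPv6 endpoint like `[fe80::1]:8080` passes through untouched; a bare IPv6 address would still be truncated at its first colon, which is a known limit of the heuristic. A standalone sketch:

```typescript
// Sketch mirroring the cleanIp guard in getInterfaceLabel above.
function stripPort(ip: string): string {
  return ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
}

console.log(stripPort('192.168.1.1:8080')); // "192.168.1.1"
console.log(stripPort('[fe80::1]:8080'));   // unchanged: the '[' disables splitting
console.log(stripPort('fe80::1'));          // "fe80" - bare IPv6 is truncated (heuristic's limit)
```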
@@ -67,6 +108,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
return lines;
|
||||
}
|
||||
|
||||
|
||||
// Apple logo path for MacBook Pro screen
|
||||
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
|
||||
const LOGO_NATIVE_WIDTH = 814;
|
||||
@@ -238,6 +280,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
|
||||
|
||||
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
|
||||
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
|
||||
edges.forEach(edge => {
|
||||
if (!edge.source || !edge.target || edge.source === edge.target) return;
|
||||
if (!positionById[edge.source] || !positionById[edge.target]) return;
|
||||
@@ -314,110 +357,98 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('marker-end', 'url(#arrowhead)');
|
||||
}
|
||||
|
||||
// Collect debug labels for later positioning at edges
|
||||
if (debugEnabled && entry.connections.length > 0) {
|
||||
const maxBoxes = 6;
|
||||
const fontSize = isMinimized ? 8 : 9;
|
||||
const lineGap = 2;
|
||||
const labelOffsetOut = Math.max(140, minDimension * 0.38);
|
||||
const labelOffsetSide = isMinimized ? 16 : 20;
|
||||
const boxWidth = 170;
|
||||
const maxLineLen = 26;
|
||||
|
||||
const connections = entry.connections.slice(0, maxBoxes);
|
||||
if (entry.connections.length > maxBoxes) {
|
||||
const remaining = entry.connections.length - maxBoxes;
|
||||
connections.push({
|
||||
from: '',
|
||||
to: '',
|
||||
ip: `(+${remaining} more)`,
|
||||
ifaceLabel: '',
|
||||
missingIface: false
|
||||
});
|
||||
}
|
||||
|
||||
let dirX = mx - centerX;
|
||||
let dirY = my - centerY;
|
||||
const dirLen = Math.hypot(dirX, dirY);
|
||||
if (dirLen < 1) {
|
||||
dirX = -uy;
|
||||
dirY = ux;
|
||||
} else {
|
||||
dirX /= dirLen;
|
||||
dirY /= dirLen;
|
||||
}
|
||||
|
||||
const nx = -dirY;
|
||||
const ny = dirX;
|
||||
|
||||
const labelXRaw = mx + dirX * labelOffsetOut + nx * labelOffsetSide;
|
||||
const labelYRaw = my + dirY * labelOffsetOut + ny * labelOffsetSide;
|
||||
const clampPad = Math.min(120, minDimension * 0.12);
|
||||
const labelX = Math.max(clampPad, Math.min(width - clampPad, labelXRaw));
|
||||
const labelY = Math.max(clampPad, Math.min(height - clampPad, labelYRaw));
|
||||
|
||||
const labelGroup = debugLabelsGroup.append('g')
|
||||
.attr('transform', `translate(${labelX}, ${labelY})`);
|
||||
|
||||
const textGroup = labelGroup.append('g');
|
||||
|
||||
connections.forEach((conn, idx) => {
|
||||
const rawLines = conn.from && conn.to
|
||||
? [
|
||||
`${getNodeLabel(conn.from)}→${getNodeLabel(conn.to)}`,
|
||||
`${conn.ip}`,
|
||||
`${conn.ifaceLabel}`
|
||||
]
|
||||
: [conn.ip];
|
||||
|
||||
const wrapped = rawLines.flatMap(line => wrapLine(line, maxLineLen));
|
||||
|
||||
wrapped.forEach((line, lineIdx) => {
|
||||
textGroup.append('text')
|
||||
.attr('x', 0)
|
||||
.attr('y', (idx * (wrapped.length * (fontSize + lineGap))) + lineIdx * (fontSize + lineGap))
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('dominant-baseline', 'hanging')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.9)')
|
||||
.text(line);
|
||||
});
|
||||
// Determine which side of viewport based on edge midpoint
|
||||
const isLeft = mx < centerX;
|
||||
const isTop = my < safeCenterY;
|
||||
|
||||
// Store for batch rendering after all edges processed
|
||||
if (!debugEdgeLabels) debugEdgeLabels = [];
|
||||
debugEdgeLabels.push({
|
||||
connections: entry.connections,
|
||||
isLeft,
|
||||
isTop,
|
||||
mx,
|
||||
my
|
||||
});
|
||||
|
||||
const bbox = textGroup.node()?.getBBox();
|
||||
if (bbox) {
|
||||
const paddedWidth = Math.max(boxWidth, bbox.width + 14);
|
||||
const boxHeight = bbox.height + 8;
|
||||
const boxMinX = labelX - paddedWidth / 2;
|
||||
const boxMaxX = labelX + paddedWidth / 2;
|
||||
const boxMinY = labelY + bbox.y - 4;
|
||||
const boxMaxY = boxMinY + boxHeight;
|
||||
|
||||
const clampPadDynamic = Math.min(140, minDimension * 0.18);
|
||||
let shiftX = 0;
|
||||
let shiftY = 0;
|
||||
if (boxMinX < clampPadDynamic) shiftX = clampPadDynamic - boxMinX;
|
||||
if (boxMaxX > width - clampPadDynamic) shiftX = (width - clampPadDynamic) - boxMaxX;
|
||||
if (boxMinY < clampPadDynamic) shiftY = clampPadDynamic - boxMinY;
|
||||
if (boxMaxY > height - clampPadDynamic) shiftY = (height - clampPadDynamic) - boxMaxY;
|
||||
|
||||
const finalX = labelX + shiftX;
|
||||
const finalY = labelY + shiftY;
|
||||
labelGroup.attr('transform', `translate(${finalX}, ${finalY})`);
|
||||
|
||||
labelGroup.insert('rect', 'g')
|
||||
.attr('x', -paddedWidth / 2)
|
||||
.attr('y', bbox.y - 4)
|
||||
.attr('width', paddedWidth)
|
||||
.attr('height', boxHeight)
|
||||
.attr('rx', 4)
|
||||
.attr('fill', 'rgba(0,0,0,0.75)')
|
||||
.attr('stroke', 'rgba(255,255,255,0.12)')
|
||||
.attr('stroke-width', 0.6);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Render debug labels at viewport edges/corners
|
||||
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
|
||||
const fontSize = isMinimized ? 10 : 12;
|
||||
const lineHeight = fontSize + 4;
|
||||
const padding = 10;
|
||||
|
||||
// Helper to get arrow based on direction vector
|
||||
function getArrow(fromId: string, toId: string): string {
|
||||
const fromPos = positionById[fromId];
|
||||
const toPos = positionById[toId];
|
||||
if (!fromPos || !toPos) return '→';
|
||||
|
||||
const dirX = toPos.x - fromPos.x;
|
||||
const dirY = toPos.y - fromPos.y;
|
||||
const absX = Math.abs(dirX);
|
||||
const absY = Math.abs(dirY);
|
||||
|
||||
if (absX > absY * 2) {
|
||||
return dirX > 0 ? '→' : '←';
|
||||
} else if (absY > absX * 2) {
|
||||
return dirY > 0 ? '↓' : '↑';
|
||||
} else {
|
||||
if (dirX > 0 && dirY > 0) return '↘';
|
||||
if (dirX > 0 && dirY < 0) return '↗';
|
||||
if (dirX < 0 && dirY > 0) return '↙';
|
||||
return '↖';
|
||||
}
|
||||
}
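The mapping above snaps to a horizontal or vertical glyph only when one axis dominates by more than 2x; everything else falls back to a diagonal. Re-stated as a standalone function with sample inputs:

```typescript
// Standalone restatement of the direction-to-glyph rule above (inputs are hypothetical).
function arrowFor(dirX: number, dirY: number): string {
  const absX = Math.abs(dirX);
  const absY = Math.abs(dirY);
  if (absX > absY * 2) return dirX > 0 ? '→' : '←';
  if (absY > absX * 2) return dirY > 0 ? '↓' : '↑';
  if (dirX > 0 && dirY > 0) return '↘';
  if (dirX > 0 && dirY < 0) return '↗';
  if (dirX < 0 && dirY > 0) return '↙';
  return '↖';
}

console.log(arrowFor(100, 10));  // '→' (x dominates by more than 2x)
console.log(arrowFor(10, -100)); // '↑' (SVG y grows downward, so negative dy points up)
console.log(arrowFor(50, 60));   // '↘' (neither axis dominates)
```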
|
||||
|
||||
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
|
||||
const quadrants: Record<string, typeof debugEdgeLabels> = {
|
||||
topLeft: [],
|
||||
topRight: [],
|
||||
bottomLeft: [],
|
||||
bottomRight: []
|
||||
};
|
||||
|
||||
debugEdgeLabels.forEach(edge => {
|
||||
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
|
||||
quadrants[key].push(edge);
|
||||
});
|
||||
|
||||
// Render each quadrant
|
||||
Object.entries(quadrants).forEach(([quadrant, edges]) => {
|
||||
if (edges.length === 0) return;
|
||||
|
||||
const isLeft = quadrant.includes('Left');
|
||||
const isTop = quadrant.includes('top');
|
||||
|
||||
let baseX = isLeft ? padding : width - padding;
|
||||
let baseY = isTop ? padding : height - padding;
|
||||
const textAnchor = isLeft ? 'start' : 'end';
|
||||
|
||||
let currentY = baseY;
|
||||
|
||||
edges.forEach(edge => {
|
||||
edge.connections.forEach(conn => {
|
||||
const arrow = getArrow(conn.from, conn.to);
|
||||
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
|
||||
debugLabelsGroup.append('text')
|
||||
.attr('x', baseX)
|
||||
.attr('y', currentY)
|
||||
.attr('text-anchor', textAnchor)
|
||||
.attr('dominant-baseline', isTop ? 'hanging' : 'auto')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
|
||||
.text(label);
|
||||
currentY += isTop ? lineHeight : -lineHeight;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Draw nodes
|
||||
const nodesGroup = svg.append('g').attr('class', 'nodes-group');
|
||||
|
||||
@@ -925,16 +956,59 @@ function wrapLine(text: string, maxLen: number): string[] {

}

// Throttled render function to prevent too-frequent updates
function scheduleRender() {
  if (isRendering) {
    pendingRender = true;
    return;
  }

  isRendering = true;
  requestAnimationFrame(() => {
    renderGraph();
    isRendering = false;

    if (pendingRender) {
      pendingRender = false;
      scheduleRender();
    }
  });
}

$effect(() => {
  if (!data || !svgContainer) return;

  // Generate hash of current state
  const currentHash = generateDataHash(data, isMinimized, debugEnabled);
  const highlightHash = Array.from(highlightedNodes).sort().join(',');

  // Get current dimensions
  const rect = svgContainer.getBoundingClientRect();
  const dimensionsChanged = rect.width !== lastDimensions.width || rect.height !== lastDimensions.height;

  // Only re-render if something actually changed
  if (currentHash !== lastRenderHash || highlightHash !== lastHighlightedNodesHash || dimensionsChanged) {
    lastRenderHash = currentHash;
    lastHighlightedNodesHash = highlightHash;
    lastDimensions = { width: rect.width, height: rect.height };
    scheduleRender();
  }
});

onMount(() => {
  if (svgContainer) {
    // Use a debounced resize observer to prevent rapid re-renders
    let resizeTimeout: ReturnType<typeof setTimeout> | null = null;

    resizeObserver = new ResizeObserver(() => {
      if (resizeTimeout) clearTimeout(resizeTimeout);
      resizeTimeout = setTimeout(() => {
        const rect = svgContainer!.getBoundingClientRect();
        if (rect.width !== lastDimensions.width || rect.height !== lastDimensions.height) {
          lastDimensions = { width: rect.width, height: rect.height };
          scheduleRender();
        }
      }, 100);
    });
    resizeObserver.observe(svgContainer);
  }
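scheduleRender above coalesces bursts: while a frame is in flight, further calls only set pendingRender, so N synchronous invalidations collapse into at most one follow-up render. The same pattern as a generic helper (a sketch, not exo code):

```typescript
// Generic requestAnimationFrame coalescer with the same shape as scheduleRender above.
function makeFrameCoalescer(work: () => void): () => void {
  let running = false;
  let pending = false;
  const schedule = (): void => {
    if (running) { pending = true; return; } // collapse bursts into one follow-up frame
    running = true;
    requestAnimationFrame(() => {
      work();
      running = false;
      if (pending) { pending = false; schedule(); }
    });
  };
  return schedule;
}

const scheduleDraw = makeFrameCoalescer(() => console.log('draw'));
scheduleDraw(); scheduleDraw(); scheduleDraw(); // at most two draws, not three
```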
@@ -962,10 +1036,20 @@ function wrapLine(text: string, maxLen: number): string[] {
stroke-width: 1px;
stroke-dasharray: 4, 4;
opacity: 0.8;
/* Slower animation = less GPU usage */
animation: flowAnimation 2s linear infinite;
/* GPU optimization */
will-change: stroke-dashoffset;
}
@keyframes flowAnimation {
  from { stroke-dashoffset: 0; }
  to { stroke-dashoffset: -10; }
}

/* Respect reduced motion preference */
@media (prefers-reduced-motion: reduce) {
  :global(.graph-link) {
    animation: none;
  }
}
</style>
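The media query above stops the dash animation for users who ask for reduced motion. If the same preference ever needs to be honored from script (say, to skip D3 transitions), it can be read with matchMedia; a sketch:

```typescript
// Hedged sketch: reading prefers-reduced-motion from script.
const prefersReducedMotion =
  typeof window !== 'undefined' &&
  window.matchMedia('(prefers-reduced-motion: reduce)').matches;

const transitionMs = prefersReducedMotion ? 0 : 300; // drop eased transitions entirely
console.log(`animating with ${transitionMs}ms transitions`);
```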
@@ -4,4 +4,5 @@ export { default as ChatMessages } from './ChatMessages.svelte';
|
||||
export { default as ChatAttachments } from './ChatAttachments.svelte';
|
||||
export { default as ChatSidebar } from './ChatSidebar.svelte';
|
||||
export { default as ModelCard } from './ModelCard.svelte';
|
||||
export { default as MarkdownContent } from './MarkdownContent.svelte';
|
||||
|
||||
|
||||
@@ -297,6 +297,35 @@ function extractIpFromMultiaddr(ma?: string): string | undefined {
return undefined;
}

// Shallow comparison utility for preventing unnecessary state updates
function shallowEqual(a: unknown, b: unknown): boolean {
  if (a === b) return true;
  if (a === null || b === null) return false;
  if (typeof a !== 'object' || typeof b !== 'object') return false;

  const aObj = a as Record<string, unknown>;
  const bObj = b as Record<string, unknown>;
  const aKeys = Object.keys(aObj);
  const bKeys = Object.keys(bObj);

  if (aKeys.length !== bKeys.length) return false;

  for (const key of aKeys) {
    if (aObj[key] !== bObj[key]) return false;
  }
  return true;
}

// JSON-string comparison for complex nested objects
function jsonEqual(a: unknown, b: unknown): boolean {
  if (a === b) return true;
  try {
    return JSON.stringify(a) === JSON.stringify(b);
  } catch {
    return false;
  }
}

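Where each helper fits: shallowEqual is O(number of keys) but only sees one level, so nested objects compare by reference; jsonEqual handles nesting at the cost of two serializations and is sensitive to key order. Illustrative calls:

```typescript
// Illustrative calls against shallowEqual and jsonEqual above.
console.log(shallowEqual({ a: 1 }, { a: 1 }));               // true  (flat, same values)
console.log(shallowEqual({ a: { b: 1 } }, { a: { b: 1 } })); // false (nested values differ by reference)
console.log(jsonEqual({ a: { b: 1 } }, { a: { b: 1 } }));    // true  (structural, via serialization)
console.log(jsonEqual({ a: 1, b: 2 }, { b: 2, a: 1 }));      // false (JSON.stringify is key-order sensitive)
```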
class AppStore {
|
||||
// Conversation state
|
||||
conversations = $state<Conversation[]>([]);
|
||||
@@ -327,19 +356,49 @@ class AppStore {
|
||||
isTopologyMinimized = $state(false);
|
||||
isSidebarOpen = $state(false); // Hidden by default, shown when in chat mode
|
||||
debugMode = $state(false);
|
||||
topologyOnlyMode = $state(false);
|
||||
chatSidebarVisible = $state(true); // Shown by default
|
||||
|
||||
// Visibility state - used to pause polling when tab is hidden
|
||||
private isPageVisible = true;
|
||||
|
||||
private fetchInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
|
||||
// Cache for comparison - prevents unnecessary reactivity
|
||||
private lastTopologyJson = '';
|
||||
private lastInstancesJson = '';
|
||||
private lastRunnersJson = '';
|
||||
private lastDownloadsJson = '';
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
this.startPolling();
|
||||
this.loadConversationsFromStorage();
|
||||
this.loadDebugModeFromStorage();
|
||||
this.loadTopologyOnlyModeFromStorage();
|
||||
this.loadChatSidebarVisibleFromStorage();
|
||||
this.setupVisibilityListener();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Listen for page visibility changes to pause polling when hidden
|
||||
*/
|
||||
private setupVisibilityListener() {
|
||||
if (typeof document === 'undefined') return;
|
||||
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
this.isPageVisible = document.visibilityState === 'visible';
|
||||
|
||||
if (this.isPageVisible) {
|
||||
// Resume polling when page becomes visible
|
||||
this.fetchState();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Load conversations from localStorage
|
||||
*/
|
||||
@@ -394,6 +453,44 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
private loadTopologyOnlyModeFromStorage() {
|
||||
try {
|
||||
const stored = localStorage.getItem('exo-topology-only-mode');
|
||||
if (stored !== null) {
|
||||
this.topologyOnlyMode = stored === 'true';
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load topology only mode:', error);
|
||||
}
|
||||
}
|
||||
|
||||
private saveTopologyOnlyModeToStorage() {
|
||||
try {
|
||||
localStorage.setItem('exo-topology-only-mode', this.topologyOnlyMode ? 'true' : 'false');
|
||||
} catch (error) {
|
||||
console.error('Failed to save topology only mode:', error);
|
||||
}
|
||||
}
|
||||
|
||||
private loadChatSidebarVisibleFromStorage() {
|
||||
try {
|
||||
const stored = localStorage.getItem('exo-chat-sidebar-visible');
|
||||
if (stored !== null) {
|
||||
this.chatSidebarVisible = stored === 'true';
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load chat sidebar visibility:', error);
|
||||
}
|
||||
}
|
||||
|
||||
private saveChatSidebarVisibleToStorage() {
|
||||
try {
|
||||
localStorage.setItem('exo-chat-sidebar-visible', this.chatSidebarVisible ? 'true' : 'false');
|
||||
} catch (error) {
|
||||
console.error('Failed to save chat sidebar visibility:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new conversation
|
||||
*/
|
||||
@@ -698,9 +795,39 @@ class AppStore {
|
||||
this.saveDebugModeToStorage();
|
||||
}
|
||||
|
||||
getTopologyOnlyMode(): boolean {
|
||||
return this.topologyOnlyMode;
|
||||
}
|
||||
|
||||
setTopologyOnlyMode(enabled: boolean) {
|
||||
this.topologyOnlyMode = enabled;
|
||||
this.saveTopologyOnlyModeToStorage();
|
||||
}
|
||||
|
||||
toggleTopologyOnlyMode() {
|
||||
this.topologyOnlyMode = !this.topologyOnlyMode;
|
||||
this.saveTopologyOnlyModeToStorage();
|
||||
}
|
||||
|
||||
getChatSidebarVisible(): boolean {
|
||||
return this.chatSidebarVisible;
|
||||
}
|
||||
|
||||
setChatSidebarVisible(visible: boolean) {
|
||||
this.chatSidebarVisible = visible;
|
||||
this.saveChatSidebarVisibleToStorage();
|
||||
}
|
||||
|
||||
toggleChatSidebarVisible() {
|
||||
this.chatSidebarVisible = !this.chatSidebarVisible;
|
||||
this.saveChatSidebarVisibleToStorage();
|
||||
}
|
||||
|
||||
startPolling() {
  this.fetchState();
  // Poll every 2 seconds instead of 1 - halves polling frequency (and its CPU/GPU cost)
  // Data comparison ensures we only update when something actually changes
  this.fetchInterval = setInterval(() => this.fetchState(), 2000);
}
|
||||
|
||||
stopPolling() {
|
||||
@@ -712,6 +839,9 @@ class AppStore {
|
||||
}
|
||||
|
||||
async fetchState() {
  // Skip polling when page is hidden to save resources
  if (!this.isPageVisible) return;
|
||||
|
||||
try {
|
||||
const response = await fetch('/state');
|
||||
if (!response.ok) {
|
||||
@@ -719,19 +849,44 @@ class AppStore {
|
||||
}
|
||||
const data: RawStateResponse = await response.json();

// Only update topology if it actually changed (prevents unnecessary D3 re-renders)
if (data.topology) {
  const newTopology = transformTopology(data.topology, data.nodeProfiles);
  const newTopologyJson = JSON.stringify(newTopology);
  if (newTopologyJson !== this.lastTopologyJson) {
    this.lastTopologyJson = newTopologyJson;
    this.topologyData = newTopology;
  }
}

// Only update instances if changed
if (data.instances) {
  const newInstancesJson = JSON.stringify(data.instances);
  if (newInstancesJson !== this.lastInstancesJson) {
    this.lastInstancesJson = newInstancesJson;
    this.instances = data.instances;
    this.refreshConversationModelFromInstances();
  }
}

// Only update runners if changed
if (data.runners) {
  const newRunnersJson = JSON.stringify(data.runners);
  if (newRunnersJson !== this.lastRunnersJson) {
    this.lastRunnersJson = newRunnersJson;
    this.runners = data.runners;
  }
}

// Only update downloads if changed
if (data.downloads) {
  const newDownloadsJson = JSON.stringify(data.downloads);
  if (newDownloadsJson !== this.lastDownloadsJson) {
    this.lastDownloadsJson = newDownloadsJson;
    this.downloads = data.downloads;
  }
}

this.lastUpdate = Date.now();
} catch (error) {
  console.error('Error fetching state:', error);
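fetchState above only assigns into $state when the serialized payload actually changed, which keeps downstream $derived/$effect consumers quiet on identical polls. The core of that pattern, as a standalone sketch:

```typescript
// Core of the compare-then-assign pattern used in fetchState above (sketch).
let lastSnapshot = '';

function assignIfChanged<T>(value: T, assign: (v: T) => void): void {
  const snapshot = JSON.stringify(value);
  if (snapshot === lastSnapshot) return; // identical payload: no reactive update
  lastSnapshot = snapshot;
  assign(value);
}

assignIfChanged({ nodes: 3 }, v => console.log('updated', v)); // fires
assignIfChanged({ nodes: 3 }, v => console.log('updated', v)); // skipped
```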
@@ -888,8 +1043,6 @@ class AppStore {
|
||||
|
||||
if (lastUserIndex === -1) return;
|
||||
|
||||
const lastUserMessage = this.messages[lastUserIndex];
|
||||
|
||||
// Remove any messages after the user message
|
||||
this.messages = this.messages.slice(0, lastUserIndex + 1);
|
||||
|
||||
@@ -930,7 +1083,10 @@ class AppStore {
|
||||
}
|
||||
|
||||
if (!modelToUse) {
|
||||
assistantMessage.content = 'Error: No model available. Please launch an instance first.';
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = 'Error: No model available. Please launch an instance first.';
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
@@ -948,7 +1104,10 @@ class AppStore {
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
assistantMessage.content = `Error: ${response.status} - ${errorText}`;
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = `Error: ${response.status} - ${errorText}`;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
@@ -956,7 +1115,10 @@ class AppStore {
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
assistantMessage.content = 'Error: No response stream available';
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = 'Error: No response stream available';
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
@@ -984,9 +1146,16 @@ class AppStore {
|
||||
const delta = json.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
fullContent += delta;
|
||||
const { displayContent } = this.stripThinkingTags(fullContent);
|
||||
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
|
||||
this.currentResponse = displayContent;
|
||||
assistantMessage.content = displayContent;
|
||||
|
||||
// Update the assistant message in place (triggers Svelte reactivity)
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
} catch {
|
||||
// Skip malformed JSON
|
||||
@@ -995,16 +1164,25 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
const { displayContent } = this.stripThinkingTags(fullContent);
|
||||
assistantMessage.content = displayContent;
|
||||
this.currentResponse = '';
|
||||
this.updateActiveConversation();
|
||||
// Final cleanup of the message
|
||||
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
|
||||
} catch (error) {
|
||||
assistantMessage.content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
|
||||
this.updateActiveConversation();
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
this.isLoading = false;
|
||||
this.currentResponse = '';
|
||||
this.updateActiveConversation();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1364,6 +1542,8 @@ export const lastUpdate = () => appStore.lastUpdate;
|
||||
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
|
||||
export const selectedChatModel = () => appStore.selectedChatModel;
|
||||
export const debugMode = () => appStore.getDebugMode();
|
||||
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
|
||||
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
|
||||
|
||||
// Actions
|
||||
export const startChat = () => appStore.startChat();
|
||||
@@ -1391,5 +1571,9 @@ export const isSidebarOpen = () => appStore.isSidebarOpen;
|
||||
export const toggleSidebar = () => appStore.toggleSidebar();
|
||||
export const toggleDebugMode = () => appStore.toggleDebugMode();
|
||||
export const setDebugMode = (enabled: boolean) => appStore.setDebugMode(enabled);
|
||||
export const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode();
|
||||
export const setTopologyOnlyMode = (enabled: boolean) => appStore.setTopologyOnlyMode(enabled);
|
||||
export const toggleChatSidebarVisible = () => appStore.toggleChatSidebarVisible();
|
||||
export const setChatSidebarVisible = (visible: boolean) => appStore.setChatSidebarVisible(visible);
|
||||
export const refreshState = () => appStore.fetchState();
|
||||
|
||||
|
||||
@@ -1,7 +1,25 @@
|
||||
<script lang="ts">
|
||||
import '../app.css';
|
||||
import { onMount } from 'svelte';
|
||||
import { browser } from '$app/environment';
|
||||
|
||||
let { children } = $props();
|
||||
let isPageHidden = $state(false);
|
||||
|
||||
onMount(() => {
|
||||
if (!browser) return;
|
||||
|
||||
// Listen for visibility changes to pause animations when hidden
|
||||
const handleVisibilityChange = () => {
|
||||
isPageHidden = document.visibilityState === 'hidden';
|
||||
};
|
||||
|
||||
document.addEventListener('visibilitychange', handleVisibilityChange);
|
||||
|
||||
return () => {
|
||||
document.removeEventListener('visibilitychange', handleVisibilityChange);
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
<svelte:head>
|
||||
@@ -9,7 +27,7 @@
|
||||
<meta name="description" content="EXO - Distributed AI Cluster Dashboard" />
|
||||
</svelte:head>
|
||||
|
||||
<div class="min-h-screen bg-background text-foreground">
|
||||
<div class="min-h-screen bg-background text-foreground" data-page-hidden={isPageHidden}>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
selectedChatModel,
|
||||
debugMode,
|
||||
toggleDebugMode,
|
||||
topologyOnlyMode,
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview
|
||||
} from '$lib/stores/app.svelte';
|
||||
@@ -37,6 +41,8 @@
|
||||
const selectedModelId = $derived(selectedPreviewModelId());
|
||||
const loadingPreviews = $derived(isLoadingPreviews());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
@@ -93,17 +99,35 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
}
|
||||
|
||||
// Compute highlighted nodes from hovered instance or hovered preview
// Memoized to avoid creating new Sets on every render
let lastHighlightedNodesKey = '';
let cachedHighlightedNodes: Set<string> = new Set();

const highlightedNodes = $derived(() => {
  // Create a key for the current state to enable memoization
  const previewKey = Array.from(hoveredPreviewNodes).sort().join(',');
  const currentKey = `${hoveredInstanceId || 'null'}:${previewKey}`;

  // Return cached value if nothing changed
  if (currentKey === lastHighlightedNodesKey) {
    return cachedHighlightedNodes;
  }

  lastHighlightedNodesKey = currentKey;

  // First check instance hover
  if (hoveredInstanceId) {
    const instanceWrapped = instanceData[hoveredInstanceId];
    cachedHighlightedNodes = unwrapInstanceNodes(instanceWrapped);
    return cachedHighlightedNodes;
  }
  // Then check preview hover
  if (hoveredPreviewNodes.size > 0) {
    cachedHighlightedNodes = hoveredPreviewNodes;
    return cachedHighlightedNodes;
  }
  cachedHighlightedNodes = new Set<string>();
  return cachedHighlightedNodes;
});
|
||||
|
||||
// Helper to estimate memory from model ID (mirrors ModelCard logic)
|
||||
@@ -472,6 +496,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
const progress = parseDownloadProgress(downloadPayload);
|
||||
if (progress) {
|
||||
// Sum all values across nodes - each node downloads independently
|
||||
totalBytes += progress.totalBytes;
|
||||
downloadedBytes += progress.downloadedBytes;
|
||||
totalSpeed += progress.speed;
|
||||
@@ -489,13 +514,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
return { isDownloading: false, progress: null, perNode: [] };
|
||||
}
|
||||
|
||||
// ETA = total remaining bytes / total speed across all nodes
|
||||
const remainingBytes = totalBytes - downloadedBytes;
|
||||
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
speed: totalSpeed,
|
||||
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
|
||||
etaMs,
|
||||
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
|
||||
completedFiles,
|
||||
totalFiles,
|
||||
@@ -505,12 +534,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
};
|
||||
}
|
||||
|
||||
// Debug: Log downloads data when it changes (disabled in production for performance)
// Uncomment for debugging:
// $effect(() => {
//   if (downloadsData && Object.keys(downloadsData).length > 0) {
//     console.log('[Download Debug] Current downloads:', downloadsData);
//   }
// });
|
||||
|
||||
// Helper to get download status for an instance
|
||||
function getInstanceDownloadStatus(instanceId: string, instanceWrapped: unknown): {
|
||||
@@ -576,6 +606,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
const progress = parseDownloadProgress(downloadPayload);
|
||||
if (progress) {
|
||||
// Sum all values across nodes - each node downloads independently
|
||||
totalBytes += progress.totalBytes;
|
||||
downloadedBytes += progress.downloadedBytes;
|
||||
totalSpeed += progress.speed;
|
||||
@@ -596,13 +627,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
return { isDownloading: false, progress: null, statusText: statusInfo.statusText, perNode: [] };
|
||||
}
|
||||
|
||||
// ETA = total remaining bytes / total speed across all nodes
|
||||
const remainingBytes = totalBytes - downloadedBytes;
|
||||
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
speed: totalSpeed,
|
||||
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
|
||||
etaMs,
|
||||
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
|
||||
completedFiles,
|
||||
totalFiles,
|
||||
@@ -618,10 +653,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
function getStatusColor(statusText: string): string {
|
||||
switch (statusText) {
|
||||
case 'FAILED': return 'text-red-400';
|
||||
case 'SHUTDOWN': return 'text-gray-400';
|
||||
case 'DOWNLOADING': return 'text-blue-400';
|
||||
case 'LOADING':
|
||||
case 'WARMING UP':
|
||||
case 'WAITING': return 'text-yellow-400';
|
||||
case 'WAITING':
|
||||
case 'INITIALIZING': return 'text-yellow-400';
|
||||
case 'RUNNING': return 'text-teal-400';
|
||||
case 'READY':
|
||||
case 'LOADED': return 'text-green-400';
|
||||
@@ -644,12 +681,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
if (!r) return null;
|
||||
const [kind] = getTagged(r);
|
||||
const statusMap: Record<string, string> = {
|
||||
RunnerWaitingForInitialization: 'WaitingForInitialization',
|
||||
RunnerInitializingBackend: 'InitializingBackend',
|
||||
RunnerWaitingForModel: 'WaitingForModel',
|
||||
RunnerLoading: 'Loading',
|
||||
RunnerLoaded: 'Loaded',
|
||||
RunnerWarmingUp: 'WarmingUp',
|
||||
RunnerReady: 'Ready',
|
||||
RunnerRunning: 'Running',
|
||||
RunnerShutdown: 'Shutdown',
|
||||
RunnerFailed: 'Failed',
|
||||
};
|
||||
return kind ? statusMap[kind] || null : null;
|
||||
@@ -660,12 +700,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
|
||||
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
|
||||
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
|
||||
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
|
||||
if (has('WarmingUp')) return { statusText: 'WARMING UP', statusClass: 'starting' };
|
||||
if (has('Running')) return { statusText: 'RUNNING', statusClass: 'running' };
|
||||
if (has('Ready')) return { statusText: 'READY', statusClass: 'loaded' };
|
||||
if (has('Loaded')) return { statusText: 'LOADED', statusClass: 'loaded' };
|
||||
if (has('WaitingForModel')) return { statusText: 'WAITING', statusClass: 'starting' };
|
||||
if (has('InitializingBackend')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
|
||||
if (has('WaitingForInitialization')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
|
||||
|
||||
return { statusText: 'RUNNING', statusClass: 'active' };
|
||||
}
|
||||
@@ -1107,16 +1150,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="shooting-star" style="top: 50%; left: 40%; --duration: 45s; --delay: 30s;"></div>
|
||||
</div>
|
||||
|
||||
<HeaderNav showHome={chatStarted} onHome={handleGoHome} />
|
||||
{#if !topologyOnlyEnabled}
|
||||
<HeaderNav
|
||||
showHome={chatStarted}
|
||||
onHome={handleGoHome}
|
||||
showSidebarToggle={true}
|
||||
sidebarVisible={sidebarVisible}
|
||||
onToggleSidebar={toggleChatSidebarVisible}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<!-- Main Content -->
|
||||
<main class="flex-1 flex overflow-hidden relative">
|
||||
<!-- Left: Conversation History Sidebar (always visible) -->
|
||||
<!-- Left: Conversation History Sidebar (hidden in topology-only mode or when toggled off) -->
|
||||
{#if !topologyOnlyEnabled && sidebarVisible}
|
||||
<div class="w-80 flex-shrink-0 border-r border-exo-yellow/10">
|
||||
<ChatSidebar class="h-full" />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if !chatStarted}
|
||||
{#if topologyOnlyEnabled}
|
||||
<!-- TOPOLOGY ONLY MODE: Full-screen topology -->
|
||||
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
|
||||
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
|
||||
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
|
||||
<!-- Exit topology-only mode button -->
|
||||
<button
|
||||
type="button"
|
||||
onclick={toggleTopologyOnlyMode}
|
||||
class="absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm"
|
||||
title="Exit topology only mode"
|
||||
>
|
||||
<svg class="w-5 h-5 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="5" r="2" fill="currentColor" />
|
||||
<circle cx="5" cy="19" r="2" fill="currentColor" />
|
||||
<circle cx="19" cy="19" r="2" fill="currentColor" />
|
||||
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if !chatStarted}
|
||||
<!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->
|
||||
<div class="flex-1 flex overflow-visible relative" in:fade={{ duration: 300 }} out:fade={{ duration: 200 }}>
|
||||
|
||||
@@ -1300,14 +1374,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{:else}
|
||||
{#each nodeProg.progress.files as f}
|
||||
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
|
||||
{@const isFileComplete = filePercent >= 100}
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
|
||||
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/80">{filePercent.toFixed(1)}%</span>
|
||||
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
|
||||
style="width: {filePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
@@ -1611,13 +1686,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
in:fade={{ duration: 300, delay: 100 }}
|
||||
>
|
||||
<div class="flex-1 overflow-y-auto px-8 py-6" bind:this={chatScrollRef}>
|
||||
<div class="max-w-3xl mx-auto">
|
||||
<div class="max-w-7xl mx-auto">
|
||||
<ChatMessages scrollParent={chatScrollRef} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
|
||||
<div class="max-w-3xl mx-auto">
|
||||
<div class="max-w-7xl mx-auto">
|
||||
<ChatForm placeholder="Ask anything" showModelSelector={true} />
|
||||
</div>
|
||||
</div>
|
||||
@@ -1655,7 +1730,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<!-- Panel Header -->
|
||||
<div class="flex items-center gap-2 mb-4">
|
||||
<div class="w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse"></div>
|
||||
<h3 class="text-sm text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
|
||||
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
|
||||
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
|
||||
</div>
|
||||
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
|
||||
@@ -1701,28 +1776,28 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="flex justify-between items-start mb-2 pl-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
|
||||
<span class="text-exo-light-gray font-mono text-xs tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
|
||||
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
|
||||
</div>
|
||||
<button
|
||||
onclick={() => deleteInstance(id)}
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400/80 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
>
|
||||
DELETE
|
||||
</button>
|
||||
</div>
|
||||
<div class="pl-2">
|
||||
<div class="text-exo-yellow text-sm font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
|
||||
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
|
||||
<a
|
||||
class="inline-flex items-center gap-1 text-[10px] text-white/60 hover:text-exo-yellow transition-colors mt-0.5"
|
||||
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
|
||||
href={`https://huggingface.co/${instanceModelId}`}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
aria-label="View model on Hugging Face"
|
||||
>
|
||||
<span>Hugging Face</span>
|
||||
<svg class="w-3 h-3" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M14 3h7v7"/>
|
||||
<path d="M10 14l11-11"/>
|
||||
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
|
||||
@@ -1733,68 +1808,84 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="text-white/60 text-xs font-mono">{instanceInfo.nodeNames.join(', ')}</div>
|
||||
{/if}
|
||||
{#if debugEnabled && instanceConnections.length > 0}
|
||||
<div class="mt-1 space-y-0.5">
|
||||
{#each instanceConnections as conn}
|
||||
<div class="text-[10px] leading-snug font-mono text-white/70">
|
||||
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
|
||||
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
|
||||
</div>
|
||||
{/each}
|
||||
<div class="mt-2 space-y-1">
|
||||
{#each instanceConnections as conn}
|
||||
<div class="text-[11px] leading-snug font-mono text-white/70">
|
||||
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
|
||||
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Download Progress -->
|
||||
{#if downloadInfo.isDownloading && downloadInfo.progress}
|
||||
<div class="mt-2 space-y-1">
|
||||
<div class="flex justify-between text-xs font-mono">
|
||||
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
|
||||
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Download Progress -->
|
||||
{#if downloadInfo.isDownloading && downloadInfo.progress}
|
||||
<div class="mt-2 space-y-1">
|
||||
<div class="flex justify-between text-sm font-mono">
|
||||
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
|
||||
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {downloadInfo.progress.percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
|
||||
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
|
||||
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
|
||||
</div>
|
||||
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {downloadInfo.progress.percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
{#if downloadInfo.perNode.length > 0}
|
||||
<div class="mt-2 space-y-1.5 max-h-48 overflow-y-auto pr-1">
|
||||
{#each downloadInfo.perNode as nodeProg}
|
||||
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
|
||||
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
|
||||
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
|
||||
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
|
||||
</div>
|
||||
</div>
|
||||
{#if downloadInfo.perNode.length > 0}
|
||||
<div class="mt-2 space-y-2 max-h-48 overflow-y-auto pr-1">
|
||||
{#each downloadInfo.perNode as nodeProg}
|
||||
{@const nodePercent = Math.min(100, Math.max(0, nodeProg.progress.percentage))}
|
||||
{@const isExpanded = instanceDownloadExpandedNodes.has(nodeProg.nodeId)}
|
||||
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
|
||||
<button
|
||||
type="button"
|
||||
class="w-full text-left space-y-1.5"
|
||||
onclick={() => toggleInstanceDownloadDetails(nodeProg.nodeId)}
|
||||
>
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
|
||||
<span class="text-white/80 truncate pr-2">{nodeProg.nodeName}</span>
|
||||
<span class="text-blue-300">{Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%</span>
|
||||
<span class="flex items-center gap-1 text-blue-300">
|
||||
{nodePercent.toFixed(1)}%
|
||||
<svg class="w-3 h-3 text-exo-light-gray" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
|
||||
</svg>
|
||||
</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mb-1.5">
|
||||
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-blue-500/80 transition-all duration-300"
|
||||
style="width: {Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {nodePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
|
||||
<span>{formatBytes(nodeProg.progress.downloadedBytes)} / {formatBytes(nodeProg.progress.totalBytes)}</span>
|
||||
<span>{formatSpeed(nodeProg.progress.speed)} • ETA {formatEta(nodeProg.progress.etaMs)}</span>
|
||||
</div>
|
||||
{#if nodeProg.progress.files.length > 0}
|
||||
{@const inProgressFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) < 100)}
|
||||
{@const completedFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) >= 100)}
|
||||
{#if inProgressFiles.length > 0}
|
||||
<div class="space-y-1">
|
||||
{#each inProgressFiles as f}
|
||||
<div class="text-[10px] font-mono text-exo-light-gray/80">
|
||||
<div class="flex items-center justify-between">
|
||||
</button>
|
||||
|
||||
{#if isExpanded}
|
||||
<div class="mt-2 space-y-1.5">
|
||||
{#if nodeProg.progress.files.length === 0}
|
||||
<div class="text-[11px] font-mono text-exo-light-gray/70">No file details reported.</div>
|
||||
{:else}
|
||||
{#each nodeProg.progress.files as f}
|
||||
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
|
||||
{@const isFileComplete = filePercent >= 100}
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
|
||||
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/70">{Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%</span>
|
||||
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/50 rounded-sm overflow-hidden mt-0.5">
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70"
|
||||
style="width: {Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
|
||||
style="width: {filePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5">
|
||||
@@ -1803,27 +1894,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{#if completedFiles.length > 0}
|
||||
<div class="pt-1 space-y-0.5">
|
||||
{#each completedFiles as f}
|
||||
<div class="text-[10px] font-mono text-exo-light-gray/70 flex items-center justify-between">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/60">100%</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-sm text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
|
||||
{:else}
|
||||
<div class="text-sm {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-xs text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
|
||||
{:else}
|
||||
<div class="text-xs {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -345,13 +345,19 @@
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
<div class="flex items-center justify-between gap-3">
<div class="min-w-0 space-y-0.5">
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
<div class="text-[11px] text-exo-light-gray font-mono truncate">
{model.modelId}
</div>
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
<div
class="text-xs font-mono text-white truncate"
title={model.prettyName ?? model.modelId}
>{model.prettyName ?? model.modelId}</div>
<div
class="text-[10px] text-exo-light-gray font-mono truncate"
title={model.modelId}
>{model.modelId}</div>
{#if model.status !== 'completed'}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
{/if}
</div>
<div class="flex items-center gap-2">
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
@@ -426,14 +432,14 @@
<style>
.downloads-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
}
@media (min-width: 1024px) {
.downloads-grid {
grid-template-columns: repeat(3, minmax(0, 1fr));
}
}
@media (min-width: 1440px) {
@media (min-width: 1600px) {
.downloads-grid {
grid-template-columns: repeat(4, minmax(0, 1fr));
}

@@ -29,10 +29,12 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
]

[project.scripts]

@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)

from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -21,11 +27,13 @@ from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -56,7 +64,7 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer

HIDE_THINKING = False
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)


def chunk_to_response(
@@ -161,7 +169,9 @@ class API:
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions")(self.chat_completions)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)

@@ -177,17 +187,32 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
)

async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
command = CreateInstance(instance=payload.instance)
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()

if required_memory > available_memory:
raise HTTPException(
status_code=400,
detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
)

command = CreateInstance(
instance=instance,
)
await self._send(command)

return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
)

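The memory guard above rejects instance creation up front rather than letting placement fail later. A minimal client-side sketch of what this looks like over HTTP (illustrative only: the endpoint path, port, and payload shape are assumptions, and `HTTPException` is assumed to surface as the usual `{"detail": ...}` JSON body):

```python
import requests

# instance_payload is a placeholder; the real body is a CreateInstanceParams object.
instance_payload = {"instance": {}}

resp = requests.post("http://localhost:52415/instance", json=instance_payload)
if resp.status_code == 400:
    # e.g. "Insufficient memory to create instance. Required: 132.0GB, Available: 96.0GB"
    print(resp.json()["detail"])
else:
    resp.raise_for_status()
```
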
async def get_placement(
@@ -352,32 +377,52 @@ class API:
instance_id=instance_id,
)

async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False

async for chunk in token_chunks:
stream.process(chunk.token_id)

delta = stream.last_content_delta
ch = stream.current_channel

if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})

if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})

if delta:
yield chunk.model_copy(update={"text": delta})

if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break

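`_process_gpt_oss` is effectively a two-state machine over the harmony parser's channel: entering the `analysis` channel opens a `<think>` tag, and leaving it (or finishing the stream) closes one. A standalone sketch of the same logic over plain `(channel, delta)` tuples instead of `TokenChunk`s (illustrative, not part of this diff):

```python
def wrap_thinking(events):
    """Mirror of the channel state machine above, over plain (channel, delta) pairs."""
    thinking = False
    for channel, delta in events:
        if channel == "analysis" and not thinking:
            thinking = True
            yield "<think>"
        if channel != "analysis" and thinking:
            thinking = False
            yield "</think>"
        if delta:
            yield delta
    if thinking:  # mirrors the finish_reason branch: close an open tag at end of stream
        yield "</think>"

print("".join(wrap_thinking([("analysis", "plan step"), ("final", "Hello")])))
# -> <think>plan step</think>Hello
```
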
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""

try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()

is_thinking = False
with recv as token_chunks:
async for chunk in token_chunks:
if HIDE_THINKING:
if chunk.text == "<think>":
is_thinking = True
if chunk.text == "</think>":
is_thinking = False
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
if not (is_thinking and HIDE_THINKING):
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"

if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break

except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -392,6 +437,59 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]

async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""

async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")

yield f"data: {chunk_response.model_dump_json()}\n\n"

if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"

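For reference, the SSE framing above (`data: {...}\n\n`, terminated by `data: [DONE]`) is the OpenAI-style convention, so plain HTTP clients can consume it line by line. A hedged consumer sketch (assumes the server from this diff on localhost:52415 and a model id from the model cards later in this changeset):

```python
import json
import requests

with requests.post(
    "http://localhost:52415/v1/chat/completions",
    json={"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "hi"}], "stream": True},
    stream=True,
) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip keep-alives / blank separators
        payload = line.removeprefix("data: ")
        if payload == "[DONE]":
            break
        print(json.loads(payload)["choices"][0])
```
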
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""

text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None

async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model

text_parts.append(chunk.text)

if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason

combined_text = "".join(text_parts)
assert model is not None

return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)

async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
@@ -399,10 +497,12 @@ class API:

async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> StreamingResponse:
"""Handle chat completions with proper streaming response."""
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")

if not any(
instance.shard_assignments.model_id == payload.model
@@ -417,10 +517,13 @@ class API:
request_params=payload,
)
await self._send(command)
return StreamingResponse(
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)

return await self._collect_chat_completion(command.command_id, parse_gpt_oss)

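With `stream` unset or false, the same endpoint now returns one aggregated `ChatCompletionResponse` instead of SSE. A minimal sketch of the non-streaming call (same host/model assumptions as the streaming example above):

```python
import requests

resp = requests.post(
    "http://localhost:52415/v1/chat/completions",
    json={"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "hi"}]},
)
resp.raise_for_status()
# One complete assistant message, assembled server-side by _collect_chat_completion.
print(resp.json()["choices"][0]["message"]["content"])
```
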
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -442,6 +545,8 @@ class API:
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
)
for card in MODEL_CARDS.values()
]
@@ -458,7 +563,7 @@ class API:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._applystate)
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
print_startup_banner(self.port)
await serve(
@@ -470,7 +575,7 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()

async def _applystate(self):
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:

@@ -7,9 +7,9 @@ from loguru import logger

from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
@@ -19,7 +19,6 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.topology import NodeInfo
@@ -130,17 +129,17 @@ def place_instance(
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
ephemeral_port = random_ephemeral_port()
hosts_by_node = get_mlx_ring_hosts_by_node(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts=[
Host(
ip=host.ip,
port=random_ephemeral_port(),
)
for host in hosts
],
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)

return target_instances

@@ -215,9 +215,11 @@ def get_mlx_ibv_devices_matrix(
continue

# Find the IP J uses to talk to I
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
@@ -238,17 +240,17 @@ def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()


def _find_interface_name_for_ip(
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
@@ -269,6 +271,109 @@ def _find_interface_name_for_ip(
return None


def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None

for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name

return None


def _find_ip_prioritised(
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.

Priority order:
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}

en0_ip = iface_map.get("en0")
if en0_ip:
return en0_ip

en1_ip = iface_map.get("en1")
if en1_ip:
return en1_ip

non_thunderbolt_ip = next(
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
)

if non_thunderbolt_ip:
return non_thunderbolt_ip

if ips:
return ips[0][0]

return None


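The fallback chain in `_find_ip_prioritised` is easiest to see decoupled from `Topology`/`NodeInfo`. A sketch over plain tuples, under the assumption that each candidate IP carries its interface name and a Thunderbolt flag (`pick_ip` is hypothetical, not in this diff):

```python
def pick_ip(candidates):
    """candidates: list of (ip, iface_name, is_thunderbolt) tuples."""
    iface_map = {iface: ip for ip, iface, _ in candidates}
    for preferred in ("en0", "en1"):  # en0, then en1, per the docstring above
        if preferred in iface_map:
            return iface_map[preferred]
    for ip, _, is_tb in candidates:
        if not is_tb:  # then any non-Thunderbolt link
            return ip
    return candidates[0][0] if candidates else None  # finally, anything at all

assert pick_ip([("10.0.0.2", "bridge0", True), ("192.168.1.5", "en1", False)]) == "192.168.1.5"
```
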
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.

Each node gets a list where:
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
- Left/right neighbors: actual connection IPs
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
"""
world_size = len(selected_cycle)
if world_size == 0:
return {}

hosts_by_node: dict[NodeId, list[Host]] = {}

for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size

hosts_for_node: list[Host] = []

for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue

if idx not in {left_rank, right_rank}:
# Placeholder IP from RFC 5737 TEST-NET-2
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue

connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
)

hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))

hosts_by_node[node_id] = hosts_for_node

return hosts_by_node


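A worked example of the shape this produces (all IPs and ports invented): for rank 0 in a four-node ring, the list covers all three cases at once, self-bind, real neighbours, and the non-neighbour placeholder.

```python
# Hypothetical hosts_by_node entry for rank 0 of a 4-node ring (values invented).
hosts_for_rank_0 = [
    {"ip": "0.0.0.0", "port": 51000},       # self: bind locally on the ephemeral port
    {"ip": "192.168.1.11", "port": 51000},  # right neighbour (rank 1)
    {"ip": "198.51.100.1", "port": 0},      # non-neighbour (rank 2): RFC 5737 TEST-NET-2
    {"ip": "192.168.1.13", "port": 51000},  # left neighbour (rank 3)
]
```

In a 3-node ring every other node is a neighbour, so the placeholder never appears there, which matches the single-node and 3-node test assertions elsewhere in this changeset.
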
def get_mlx_jaccl_coordinators(
selected_cycle: list[NodeInfo],
coordinator_port: int,
@@ -286,7 +391,7 @@ def get_mlx_jaccl_coordinators(
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"

for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
return ip

logger.warning(

@@ -123,6 +123,8 @@ async def test_master():
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
@@ -163,32 +165,38 @@ async def test_master():
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()
)[0]
assert events[1].event.instance == MlxRingInstance(
instance_id=events[1].event.instance.instance_id,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
# Validate the shard assignments
expected_shard_assignments = ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
),
hosts=[],
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
)
assert created_instance.shard_assignments == expected_shard_assignments
# For single-node, hosts_by_node should have one entry with self-binding
assert len(created_instance.hosts_by_node) == 1
assert node_id in created_instance.hosts_by_node
assert len(created_instance.hosts_by_node[node_id]) == 1
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)

@@ -38,7 +38,8 @@ def instance() -> Instance:
shard_assignments=ShardAssignments(
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
),
hosts=[],
hosts_by_node={},
ephemeral_port=50000,
)


@@ -49,6 +50,8 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=10,
supports_tensor=True,
)


@@ -92,9 +95,13 @@ def test_get_instance_placements_create_instance(
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))

# act
placements = place_instance(cic, topology, {})
@@ -135,6 +142,8 @@ def test_get_instance_placements_one_node_exact_fit(
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
@@ -160,6 +169,8 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
@@ -185,6 +196,8 @@ def test_get_instance_placements_one_node_not_fit(
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)

@@ -234,17 +247,15 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id


def test_placement_prioritizes_leaf_cycle_with_less_memory(
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
# cycle having LESS total memory than D-E-F. The algorithm should still choose
# the cycle that contains a leaf node.
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.

# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
@@ -258,11 +269,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
node_id_e = NodeId()
node_id_f = NodeId()

# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
node_id_x = NodeId()
node_id_y = NodeId()
node_id_z = NodeId()

# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
@@ -273,24 +279,20 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))

# Extra nodes with tiny memory so they can't form singleton placements
topology.add_node(create_node(10, node_id_x))
topology.add_node(create_node(10, node_id_y))
topology.add_node(create_node(10, node_id_z))

# Build directed cycles
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))

topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))

# Add extra outgoing edges from D/E/F so none of them are leaves
topology.add_connection(create_connection(node_id_d, node_id_x))
topology.add_connection(create_connection(node_id_e, node_id_y))
topology.add_connection(create_connection(node_id_f, node_id_z))
topology.add_connection(create_connection(node_id_d, node_id_f))

cic = place_instance_command(
model_meta=model_meta,
@@ -299,18 +301,17 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
# Act
placements = place_instance(cic, topology, {})

# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]

assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}

assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)


def test_tensor_rdma_backend_connectivity_matrix(

@@ -198,6 +198,8 @@ def test_get_shard_assignments(
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]

@@ -51,6 +51,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
@@ -64,6 +66,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
@@ -135,6 +139,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
@@ -148,6 +154,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
@@ -162,6 +170,38 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
@@ -175,6 +215,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
@@ -189,6 +231,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
@@ -202,6 +246,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
@@ -215,6 +261,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
@@ -229,6 +277,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
@@ -242,6 +292,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
@@ -255,20 +307,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
),
),
# phi-3
"phi-3-mini": ModelCard(
short_id="phi-3-mini",
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
name="Phi 3 Mini 128k (4-bit)",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
pretty_name="Phi 3 Mini 128k (4-bit)",
storage_size=Memory.from_mb(2099),
n_layers=32,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
@@ -283,6 +323,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
@@ -296,6 +338,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
@@ -309,6 +353,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
@@ -322,6 +368,68 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
@@ -335,6 +443,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
@@ -348,6 +458,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
@@ -361,6 +473,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
@@ -374,77 +488,84 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# granite
"granite-3.3-2b": ModelCard(
short_id="granite-3.3-2b",
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
name="Granite 3.3 2B (FP16)",
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
pretty_name="Granite 3.3 2B (FP16)",
storage_size=Memory.from_mb(4951),
n_layers=40,
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
# "granite-3.3-8b": ModelCard(
# short_id="granite-3.3-8b",
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# name="Granite 3.3 8B",
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# Needs to be quantized g32 or g16.
"glm-4.5-air-8bit": ModelCard(
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# pretty_name="Granite 3.3 8B",
# storage_size=Memory.from_kb(15958720),
# n_layers=40,
# ),
# ),
# smol-lm
# "smol-lm-135m": ModelCard(
# short_id="smol-lm-135m",
# model_id="mlx-community/SmolLM-135M-4bit",
# name="Smol LM 135M",
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
# pretty_name="Smol LM 135M",
# storage_size=Memory.from_kb(73940),
# n_layers=30,
# ),
# ),
# gpt-oss
# "gpt-oss-120b-MXFP4-Q8": ModelCard(
# short_id="gpt-oss-120b-MXFP4-Q8",
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# storage_size=Memory.from_kb(68_996_301),
# n_layers=36,
# hidden_size=2880,
# supports_tensor=True,
# ),
# ),
# "gpt-oss-20b-4bit": ModelCard(
# short_id="gpt-oss-20b-4bit",
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# storage_size=Memory.from_kb(11_744_051),
# n_layers=24,
# hidden_size=2880,
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),

@@ -6,6 +6,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
@@ -25,6 +26,7 @@ class ConfigData(BaseModel):
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None

@property
def layer_count(self) -> int:
@@ -106,10 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)

return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_id,
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

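The lookup-with-fallback here is small but load-bearing: repos that match a curated card get the card's display name and `supports_tensor` flag, while unknown repos degrade to the raw id and `supports_tensor=False`. The same rule in miniature (a sketch with hypothetical plain dicts standing in for `ModelCard`):

```python
def resolve_display(model_id, cards):
    """cards: dict of {"model_id": ..., "name": ..., "supports_tensor": ...} dicts."""
    card = next((c for c in cards.values() if c["model_id"] == model_id), None)
    pretty = card["name"] if card is not None else model_id
    supports_tensor = card["supports_tensor"] if card is not None else False
    return pretty, supports_tensor

cards = {"llama": {"model_id": "mlx-community/Llama", "name": "Llama (4-bit)", "supports_tensor": True}}
assert resolve_display("mlx-community/Llama", cards) == ("Llama (4-bit)", True)
assert resolve_display("unknown/repo", cards) == ("unknown/repo", False)
```
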
@@ -36,6 +36,8 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,
supports_tensor=True,
),
device_rank=device_rank,
world_size=world_size,

@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
NodeDownloadProgress(download_progress=event), state
)

assert new_state == State(downloads={NodeId("node-1"): [event]})
assert new_state.downloads == {NodeId("node-1"): [event]}


def test_apply_two_node_download_progress():
@@ -42,4 +42,4 @@ def test_apply_two_node_download_progress():
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding

@@ -174,6 +174,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata


class DeleteInstanceResponse(BaseModel):

@@ -14,3 +14,5 @@ class ModelMetadata(CamelCaseModel):
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

@@ -40,6 +40,10 @@ class LoadModel(BaseTask): # emitted by Worker
pass


class ConnectToGroup(BaseTask): # emitted by Worker
pass


class StartWarmup(BaseTask): # emitted by Worker
pass

@@ -57,5 +61,11 @@ class Shutdown(BaseTask): # emitted by Worker


Task = (
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
CreateRunner
| DownloadModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
)

@@ -25,7 +25,8 @@ class BaseInstance(TaggedModel):


class MlxRingInstance(BaseInstance):
hosts: list[Host]
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int


class MlxJacclInstance(BaseInstance):

@@ -21,7 +21,15 @@ class BaseRunnerStatus(TaggedModel):
return isinstance(self, RunnerRunning)


class RunnerWaitingForModel(BaseRunnerStatus):
class RunnerIdle(BaseRunnerStatus):
pass


class RunnerConnecting(BaseRunnerStatus):
pass


class RunnerConnected(BaseRunnerStatus):
pass


@@ -54,7 +62,9 @@ class RunnerFailed(BaseRunnerStatus):


RunnerStatus = (
RunnerWaitingForModel
RunnerIdle
| RunnerConnecting
| RunnerConnected
| RunnerLoading
| RunnerLoaded
| RunnerWarmingUp

@@ -95,7 +95,15 @@ def extract_layer_num(tensor_name: str) -> int | None:

def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list[str]:
default_patterns = set(
["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt", "*.jinja"]
[
"*.json",
"*.py",
"tokenizer.model",
"tiktoken.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if weight_map:

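Allow patterns of this shape are matched fnmatch-style against repo file paths by the Hugging Face download tooling, so the defaults pull in configs and tokenizer assets while weights are only fetched if a shard-specific pattern adds them back. A quick illustration (the file list is invented):

```python
from fnmatch import fnmatch

allow_patterns = ["*.json", "*.py", "tokenizer.model", "tiktoken.model", "*.tiktoken", "*.txt", "*.jinja"]
files = ["config.json", "model-00001-of-00002.safetensors", "tokenizer.model"]

# Keep only files that match at least one pattern, as snapshot downloads do.
kept = [f for f in files if any(fnmatch(f, p) for p in allow_patterns)]
print(kept)  # ['config.json', 'tokenizer.model']
```
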
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -12,7 +13,7 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import RepoDownloadProgress


# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Shoudl this be a classmethod?
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
@abstractmethod
async def ensure_shard(
@@ -43,34 +44,7 @@ class ShardDownloader(ABC):
Yields:
tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.
"""
yield (
Path("/tmp/noop_shard"),
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
)
yield (Path("/tmp/noop_shard"), NOOP_DOWNLOAD_PROGRESS)

@abstractmethod
async def get_shard_download_status_for_shard(
@@ -94,46 +68,41 @@ class NoopShardDownloader(ShardDownloader):
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
yield (
Path("/tmp/noop_shard"),
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
NOOP_DOWNLOAD_PROGRESS,
)

async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
return RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=shard,
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)
dp = copy(NOOP_DOWNLOAD_PROGRESS)
dp.shard = shard
return dp


NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)

@@ -10,7 +10,6 @@ KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
TEMPERATURE: float = 1.0

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

@@ -13,7 +13,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
TEMPERATURE,
TRUST_REMOTE_CODE,
)

@@ -21,6 +20,8 @@ try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
import contextlib

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load_model
@@ -48,6 +49,7 @@ from exo.worker.engines.mlx.auto_parallel import (
)
from exo.worker.runner.bootstrap import logger

Group = mx.distributed.Group
# Needed for 8 bit model
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))

@@ -67,7 +69,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)


def mx_barrier(group: mx.distributed.Group | None = None):
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
@@ -77,7 +79,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
)


def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value

@@ -99,91 +101,97 @@ class HostList(RootModel[list[str]]):

def mlx_distributed_init(
bound_instance: BoundInstance,
) -> mx.distributed.Group:
) -> Group:
"""
Initialize the MLX distributed (runs in thread pool).

Either hosts or mlx_ibv_devices must be provided:
- hosts: traditional host-based connectivity using MLX_HOSTFILE
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
Initialize MLX distributed.
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")

# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
hostfile = f"./hosts_{rank}.json"
hosts_json = HostList.from_hosts(hosts).model_dump_json()
coordination_file = None
try:
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()

with open(hostfile, "w") as f:
_ = f.write(hosts_json)
with open(coordination_file, "w") as f:
_ = f.write(hosts_json)

logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
logger.info(
f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
)

os.environ["MLX_HOSTFILE"] = hostfile
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
os.environ["MLX_HOSTFILE"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)

case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)
case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
ibv_devices_json = json.dumps(ibv_devices)

with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)
with open(coordination_file, "w") as f:
_ = f.write(ibv_devices_json)

jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]

logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

logger.info(f"Rank {rank} mlx distributed initialization complete")
logger.info(f"Rank {rank} mlx distributed initialization complete")

return group
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)


def initialize_mlx(
|
||||
bound_instance: BoundInstance,
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
"""
|
||||
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
|
||||
"""
|
||||
) -> Group:
|
||||
# should we unseed it?
|
||||
# TODO: pass in seed from params
|
||||
mx.random.seed(42)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
|
||||
"Tried to initialize mlx for a single node instance"
|
||||
)
|
||||
return mlx_distributed_init(bound_instance)
|
||||
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
# TODO: pass temperature
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
|
||||
logger.info("Created a sampler")
|
||||
|
||||
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
pass
|
||||
# model, config = quantize_model(
|
||||
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
|
||||
# )
|
||||
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
@@ -193,14 +201,12 @@ def initialize_mlx(
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
logger.debug(model)
|
||||
|
||||
return cast(Model, model), tokenizer, sampler
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: mx.distributed.Group,
|
||||
group: Group,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
|
||||
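Note: the two helpers retyped above both piggyback on `mx.distributed.all_sum`. A minimal standalone sketch of the idea — the body of `broadcast_from_zero` is elided in this diff, so the implementation shown here is an assumption, not the repo's exact code:

```python
import mlx.core as mx


def mx_barrier(group: mx.distributed.Group | None = None) -> None:
    # Every rank must contribute to the all_sum before any rank can read the
    # result, so forcing evaluation acts as a synchronization barrier.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), group=group))


def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None) -> int:
    # Assumed implementation: rank 0 contributes `value`, every other rank
    # contributes 0, so the all_sum equals rank 0's value on all ranks.
    if group is None:
        return value
    contribution = mx.array(value if group.rank() == 0 else 0)
    return int(mx.distributed.all_sum(contribution, group=group).item())
```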

@@ -228,7 +228,7 @@ class Worker:
                 )
             )
         else:
-            self.event_sender.send_nowait(
+            await self.event_sender.send(
                 TaskStatusUpdated(
                     task_id=task.task_id, task_status=TaskStatus.Running
                 )
@@ -414,9 +414,14 @@ class Worker:
         while True:
             # TODO: EdgeDeleted
             edges = set(self.state.topology.list_connections())
-            conns = await check_reachable(self.state.topology)
+            conns = await check_reachable(self.state.topology, self.node_id)
             for nid in conns:
                 for ip in conns[nid]:
+                    if "127.0.0.1" in ip or "localhost" in ip:
+                        logger.warning(
+                            f"Loopback connection should not happen: {ip=} for {nid=}"
+                        )
+
                     edge = Connection(
                         local_node_id=self.node_id,
                         send_back_node_id=nid,

@@ -5,6 +5,7 @@ from collections.abc import Mapping, Sequence
 from exo.shared.types.common import NodeId
 from exo.shared.types.tasks import (
     ChatCompletion,
+    ConnectToGroup,
     CreateRunner,
     DownloadModel,
     LoadModel,
@@ -14,17 +15,23 @@ from exo.shared.types.tasks import (
     TaskId,
     TaskStatus,
 )
-from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
+from exo.shared.types.worker.downloads import (
+    DownloadCompleted,
+    DownloadOngoing,
+    DownloadProgress,
+)
 from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
 from exo.shared.types.worker.runners import (
+    RunnerConnected,
+    RunnerConnecting,
     RunnerFailed,
     RunnerId,
     RunnerIdle,
     RunnerLoaded,
     RunnerLoading,
     RunnerReady,
     RunnerRunning,
     RunnerStatus,
     RunnerWaitingForModel,
     RunnerWarmingUp,
 )
 from exo.shared.types.worker.shards import ShardMetadata
@@ -48,6 +55,7 @@ def plan(
         _kill_runner(runners, all_runners, instances)
         or _create_runner(node_id, runners, instances)
         or _model_needs_download(runners, download_status)
+        or _init_distributed_backend(runners, all_runners)
         or _load_model(runners, all_runners, global_download_status)
         or _ready_to_warmup(runners, all_runners)
         or _pending_tasks(runners, tasks, all_runners)
@@ -106,9 +114,11 @@ def _model_needs_download(
     download_status: Mapping[ShardMetadata, DownloadProgress],
 ) -> DownloadModel | None:
     for runner in runners.values():
-        if (
-            isinstance(runner.status, RunnerWaitingForModel)
-            and runner.bound_instance.bound_shard not in download_status
+        if isinstance(runner.status, RunnerIdle) and (
+            not isinstance(
+                download_status.get(runner.bound_instance.bound_shard, None),
+                (DownloadOngoing, DownloadCompleted),
+            )
         ):
             # We don't invalidate download_status randomly in case a file gets deleted on disk
             return DownloadModel(
@@ -117,14 +127,54 @@ def _model_needs_download(
             )
 
 
-""" --- TODO!
-def _init_backend(
+def _init_distributed_backend(
     runners: Mapping[RunnerId, RunnerSupervisor],
     all_runners: Mapping[RunnerId, RunnerStatus],
-) -> LoadModel | None:
-    for runner in runner.values()
-    pass
-"""
+):
+    for runner in runners.values():
+        instance = runner.bound_instance.instance
+        shard_assignments = instance.shard_assignments
+
+        is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
+        if is_single_node_instance:
+            continue
+
+        runner_is_idle = isinstance(runner.status, RunnerIdle)
+        all_runners_connecting = all(
+            isinstance(
+                all_runners.get(global_runner_id),
+                (RunnerConnecting, RunnerIdle),
+            )
+            for global_runner_id in shard_assignments.runner_to_shard
+        )
+
+        if not (runner_is_idle and all_runners_connecting):
+            continue
+
+        runner_id = runner.bound_instance.bound_runner_id
+
+        shard = runner.bound_instance.bound_shard
+        device_rank = shard.device_rank
+        world_size = shard.world_size
+
+        assert device_rank < world_size
+        assert device_rank >= 0
+
+        accepting_ranks = device_rank < world_size - 1
+
+        # Rank = n-1
+        connecting_rank_ready = device_rank == world_size - 1 and all(
+            isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
+            for global_runner_id in shard_assignments.runner_to_shard
+            if global_runner_id != runner_id
+        )
+
+        if not (accepting_ranks or connecting_rank_ready):
+            continue
+
+        return ConnectToGroup(instance_id=instance.instance_id)
+
+    return None
 
 
 def _load_model(
@@ -136,31 +186,33 @@ def _load_model(
         instance = runner.bound_instance.instance
         shard_assignments = instance.shard_assignments
 
-        all_downloads_complete_local = all(
+        all_local_downloads_complete = all(
            nid in global_download_status
            and any(
                isinstance(dp, DownloadCompleted)
-                and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
+                and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
                for dp in global_download_status[nid]
            )
-            for nid, rid in shard_assignments.node_to_runner.items()
+            for nid in shard_assignments.node_to_runner
        )
+        if not all_local_downloads_complete:
+            continue
 
-        runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
+        is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
+        if is_single_node_instance and isinstance(runner.status, RunnerIdle):
+            return LoadModel(instance_id=instance.instance_id)
 
-        all_runners_expecting_model = all(
+        is_runner_waiting = isinstance(runner.status, RunnerConnected)
+
+        all_ready_for_model = all(
            isinstance(
-                all_runners.get(global_runner_id),
-                (RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
+                all_runners.get(global_runner_id, None),
+                (RunnerConnected, RunnerLoading, RunnerLoaded),
            )
            for global_runner_id in shard_assignments.runner_to_shard
        )
 
-        if (
-            all_downloads_complete_local
-            and runner_is_waiting
-            and all_runners_expecting_model
-        ):
+        if is_runner_waiting and all_ready_for_model:
            return LoadModel(instance_id=instance.instance_id)
 
     return None
@@ -183,8 +235,9 @@ def _ready_to_warmup(
         assert device_rank < world_size
         assert device_rank >= 0
 
-        # Rank != n-1
-        accepting_ranks_ready = device_rank != world_size - 1 and all(
+        # TODO: Ensure these align with MLX distributeds expectations.
+        # Rank < n-1
+        accepting_ranks_ready = device_rank < world_size - 1 and all(
            isinstance(
                all_runners.get(global_runner_id, None),
                (RunnerLoaded, RunnerWarmingUp),
@@ -221,6 +274,8 @@ def _pending_tasks(
         if task.instance_id != runner.bound_instance.instance.instance_id:
             continue
 
+        # TODO: Check ordering aligns with MLX distributeds expectations.
+
         if isinstance(runner.status, RunnerReady) and all(
             isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
             for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
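Note: the new `_init_distributed_backend` staggers group formation by rank: every rank except the last may be told to connect immediately, while rank `world_size - 1` waits until all of its peers report `RunnerConnecting`. A standalone sketch of that predicate, with runner statuses reduced to plain strings purely for illustration:

```python
from collections.abc import Mapping


def may_connect(device_rank: int, world_size: int, peer_status: Mapping[int, str]) -> bool:
    """Sketch of the rank-ordering rule used by _init_distributed_backend."""
    assert 0 <= device_rank < world_size
    # "Accepting" ranks (0 .. n-2) open listening sockets, so they can be
    # told to connect right away.
    if device_rank < world_size - 1:
        return True
    # The final "connecting" rank dials the others, so it waits until every
    # peer is already in the connecting state.
    return all(peer_status.get(rank) == "connecting" for rank in range(world_size - 1))
```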

@@ -2,16 +2,13 @@ import os
 
 import loguru
 
-from exo.shared.types.events import Event
+from exo.shared.types.events import Event, RunnerStatusUpdated
 from exo.shared.types.tasks import Task
 from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
+from exo.shared.types.worker.runners import RunnerFailed
 from exo.utils.channels import MpReceiver, MpSender
 
-logger: "loguru.Logger"
-
-
-if os.getenv("EXO_TESTS") == "1":
-    logger = loguru.logger
+logger: "loguru.Logger" = loguru.logger
 
 
 def entrypoint(
@@ -30,6 +27,23 @@ def entrypoint(
     logger = _logger
 
     # Import main after setting global logger - this lets us just import logger from this module
-    from exo.worker.runner.runner import main
+    try:
+        from exo.worker.runner.runner import main
 
-    main(bound_instance, event_sender, task_receiver)
+        main(bound_instance, event_sender, task_receiver)
+    except Exception as e:
+        logger.opt(exception=e).warning(
+            f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
+        )
+        event_sender.send(
+            RunnerStatusUpdated(
+                runner_id=bound_instance.bound_runner_id,
+                runner_status=RunnerFailed(error_message=str(e)),
+            )
+        )
+    finally:
+        event_sender.close()
+        task_receiver.close()
+        event_sender.join()
+        task_receiver.join()
+        logger.info("bye from the runner")

@@ -11,6 +11,7 @@ from exo.shared.types.events import (
 )
 from exo.shared.types.tasks import (
     ChatCompletion,
+    ConnectToGroup,
     LoadModel,
     Shutdown,
     StartWarmup,
@@ -22,20 +23,23 @@ from exo.shared.types.worker.runner_response import (
     GenerationResponse,
 )
 from exo.shared.types.worker.runners import (
+    RunnerConnected,
+    RunnerConnecting,
     RunnerFailed,
+    RunnerIdle,
     RunnerLoaded,
     RunnerLoading,
     RunnerReady,
     RunnerRunning,
     RunnerShutdown,
     RunnerStatus,
-    RunnerWaitingForModel,
     RunnerWarmingUp,
 )
 from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
 from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
 from exo.worker.engines.mlx.utils_mlx import (
     initialize_mlx,
+    load_mlx_items,
     mlx_force_oom,
 )
 from exo.worker.runner.bootstrap import logger
@@ -63,9 +67,10 @@ def main(
     model = None
     tokenizer = None
     sampler = None
+    group = None
 
-    current_status: RunnerStatus = RunnerWaitingForModel()
-    logger.info("runner waiting for model")
+    current_status: RunnerStatus = RunnerIdle()
+    logger.info("runner created")
     event_sender.send(
         RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
     )
@@ -78,9 +83,26 @@ def main(
             )
             event_sender.send(TaskAcknowledged(task_id=task.task_id))
             match task:
-                case LoadModel() if isinstance(
-                    current_status, (RunnerWaitingForModel, RunnerFailed)
+                case ConnectToGroup() if isinstance(
+                    current_status, (RunnerIdle, RunnerFailed)
                 ):
+                    logger.info("runner connecting")
+                    current_status = RunnerConnecting()
+                    event_sender.send(
+                        RunnerStatusUpdated(
+                            runner_id=runner_id, runner_status=current_status
+                        )
+                    )
+                    group = initialize_mlx(bound_instance)
+
+                    logger.info("runner connected")
+                    current_status = RunnerConnected()
+
+                # we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
+                case LoadModel() if (
+                    isinstance(current_status, RunnerConnected)
+                    and group is not None
+                ) or (isinstance(current_status, RunnerIdle) and group is None):
                     current_status = RunnerLoading()
                     logger.info("runner loading")
                     event_sender.send(
@@ -89,15 +111,12 @@ def main(
                         )
                     )
 
-                    model, tokenizer, sampler = initialize_mlx(bound_instance)
+                    model, tokenizer, sampler = load_mlx_items(
+                        bound_instance, group
+                    )
 
                     current_status = RunnerLoaded()
                     logger.info("runner loaded")
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=current_status
-                        )
-                    )
                 case StartWarmup() if isinstance(current_status, RunnerLoaded):
                     assert model
                     assert tokenizer
@@ -123,11 +142,6 @@ def main(
                     )
                     current_status = RunnerReady()
                     logger.info("runner ready")
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=RunnerReady()
-                        )
-                    )
                 case ChatCompletion(
                     task_params=task_params, command_id=command_id
                 ) if isinstance(current_status, RunnerReady):
@@ -172,11 +186,6 @@ def main(
 
                     current_status = RunnerReady()
                     logger.info("runner ready")
-                    event_sender.send(
-                        RunnerStatusUpdated(
-                            runner_id=runner_id, runner_status=RunnerReady()
-                        )
-                    )
                 case Shutdown():
                     logger.info("runner shutting down")
                     event_sender.send(
@@ -186,12 +195,19 @@ def main(
                     )
                     break
                 case _:
-                    raise ValueError("Received task outside of state machine")
+                    raise ValueError(
+                        f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
+                    )
             event_sender.send(
                 TaskStatusUpdated(
                     task_id=task.task_id, task_status=TaskStatus.Complete
                 )
             )
+            event_sender.send(
+                RunnerStatusUpdated(
+                    runner_id=runner_id, runner_status=current_status
+                )
+            )
     event_sender.send(
         RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
     )
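Note: with these changes the runner walks a strict state machine, and each task completion is now followed by a single status broadcast at the bottom of the loop. The happy-path transitions the match statement encodes, written out as a reading aid (names come from the diff; this table itself is not part of the change):

```python
# task received -> (statuses it is accepted in, resulting status)
EXPECTED_TRANSITIONS: dict[str, tuple[tuple[str, ...], str]] = {
    "ConnectToGroup": (("RunnerIdle", "RunnerFailed"), "RunnerConnected"),
    "LoadModel": (("RunnerConnected", "RunnerIdle"), "RunnerLoaded"),  # with / without a group
    "StartWarmup": (("RunnerLoaded",), "RunnerReady"),
    "ChatCompletion": (("RunnerReady",), "RunnerReady"),  # passes through RunnerRunning
    "Shutdown": ((), "RunnerShutdown"),  # accepted in any state
}
```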

@@ -19,8 +19,8 @@ from exo.shared.types.tasks import Task, TaskId
 from exo.shared.types.worker.instances import BoundInstance
 from exo.shared.types.worker.runners import (
     RunnerFailed,
+    RunnerIdle,
     RunnerStatus,
-    RunnerWaitingForModel,
 )
 from exo.shared.types.worker.shards import ShardMetadata
 from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -41,7 +41,7 @@ class RunnerSupervisor:
     _event_sender: Sender[Event]
     # err_path: str
     _tg: TaskGroup | None = field(default=None, init=False)
-    status: RunnerStatus = field(default_factory=RunnerWaitingForModel, init=False)
+    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
     pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
 
     @classmethod

@@ -24,3 +24,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
 
 COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
 COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")
+
+SHUTDOWN_TASK_ID = TaskId("shutdown")
+CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
+INITIALIZATION_TASK_ID = TaskId("initialisation")
+LOAD_TASK_ID = TaskId("load")
+WARMUP_TASK_ID = TaskId("warmup")

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
 
 from exo.shared.types.common import NodeId
@@ -14,6 +16,7 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
 from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
 
 
+# Runner supervisor without multiprocessing logic.
 @dataclass(frozen=True)
 class FakeRunnerSupervisor:
     bound_instance: BoundInstance
@@ -35,6 +38,8 @@ def get_pipeline_shard_metadata(
             pretty_name=str(model_id),
             storage_size=Memory.from_mb(100000),
             n_layers=32,
+            hidden_size=2048,
+            supports_tensor=False,
         ),
         device_rank=device_rank,
         world_size=world_size,
@@ -67,5 +72,27 @@ def get_mlx_ring_instance(
         shard_assignments=get_shard_assignments(
             model_id, node_to_runner, runner_to_shard
         ),
-        hosts=[],
+        hosts_by_node={},
+        ephemeral_port=50000,
     )
+
+
+def get_bound_mlx_ring_instance(
+    instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
+) -> BoundInstance:
+    shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)
+    other_shard = get_pipeline_shard_metadata(
+        model_id=model_id, device_rank=1, world_size=2
+    )
+    instance = get_mlx_ring_instance(
+        instance_id=instance_id,
+        model_id=model_id,
+        node_to_runner={
+            node_id: runner_id,
+            NodeId("other_node"): RunnerId("other_runner"),
+        },
+        runner_to_shard={runner_id: shard, RunnerId("other_runner"): other_shard},
+    )
+    return BoundInstance(
+        instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
+    )

@@ -4,7 +4,8 @@ from exo.shared.types.tasks import LoadModel
 from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
 from exo.shared.types.worker.instances import BoundInstance
 from exo.shared.types.worker.runners import (
-    RunnerWaitingForModel,
+    RunnerConnected,
+    RunnerIdle,
 )
 from exo.shared.types.worker.shards import ShardMetadata
 from exo.worker.tests.constants import (
@@ -38,13 +39,11 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
     bound_instance = BoundInstance(
         instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
     )
-    runner = FakeRunnerSupervisor(
-        bound_instance=bound_instance, status=RunnerWaitingForModel()
-    )
+    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
 
     runners = {RUNNER_1_ID: runner}
     instances = {INSTANCE_1_ID: instance}
-    all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
+    all_runners = {RUNNER_1_ID: RunnerIdle()}
 
     # No entry for this shard -> should trigger DownloadModel
     download_status: dict[ShardMetadata, DownloadProgress] = {}
@@ -82,15 +81,15 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
         instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
     )
     local_runner = FakeRunnerSupervisor(
-        bound_instance=bound_instance, status=RunnerWaitingForModel()
+        bound_instance=bound_instance, status=RunnerConnected()
     )
 
     runners = {RUNNER_1_ID: local_runner}
     instances = {INSTANCE_1_ID: instance}
 
     all_runners = {
-        RUNNER_1_ID: RunnerWaitingForModel(),
-        RUNNER_2_ID: RunnerWaitingForModel(),
+        RUNNER_1_ID: RunnerConnected(),
+        RUNNER_2_ID: RunnerConnected(),
     }
 
     # Local node has already marked its shard as downloaded (not actually used by _load_model)
@@ -133,13 +132,11 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
     bound_instance = BoundInstance(
         instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
     )
-    runner = FakeRunnerSupervisor(
-        bound_instance=bound_instance, status=RunnerWaitingForModel()
-    )
+    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
 
     runners = {RUNNER_1_ID: runner}
     instances = {INSTANCE_1_ID: instance}
-    all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
+    all_runners = {RUNNER_1_ID: RunnerIdle()}
 
     # Local status claims the shard is downloaded already
     local_download_status = {
@@ -183,14 +180,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
         instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
     )
     local_runner = FakeRunnerSupervisor(
-        bound_instance=bound_instance, status=RunnerWaitingForModel()
+        bound_instance=bound_instance, status=RunnerConnected()
     )
 
     runners = {RUNNER_1_ID: local_runner}
     instances = {INSTANCE_1_ID: instance}
     all_runners = {
-        RUNNER_1_ID: RunnerWaitingForModel(),
-        RUNNER_2_ID: RunnerWaitingForModel(),
+        RUNNER_1_ID: RunnerConnected(),
+        RUNNER_2_ID: RunnerConnected(),
     }
 
     # Only NODE_A's shard is recorded as downloaded globally
@@ -213,3 +210,22 @@
     )
 
     assert result is None
+
+    global_download_status = {
+        NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
+        NODE_B: [
+            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
+        ],  # NODE_B has no downloads completed yet
+    }
+
+    result = plan_mod.plan(
+        node_id=NODE_A,
+        runners=runners,  # type: ignore
+        download_status=local_download_status,
+        global_download_status=global_download_status,
+        instances=instances,
+        all_runners=all_runners,
+        tasks={},
+    )
+
+    assert result is not None

@@ -5,9 +5,9 @@ from exo.shared.types.api import ChatCompletionTaskParams
 from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
 from exo.shared.types.worker.instances import BoundInstance, InstanceId
 from exo.shared.types.worker.runners import (
+    RunnerIdle,
     RunnerReady,
     RunnerRunning,
-    RunnerWaitingForModel,
 )
 from exo.worker.tests.constants import (
     COMMAND_1_ID,
@@ -99,7 +99,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
     instances = {INSTANCE_1_ID: instance}
     all_runners = {
         RUNNER_1_ID: RunnerReady(),
-        RUNNER_2_ID: RunnerWaitingForModel(),
+        RUNNER_2_ID: RunnerIdle(),
     }
 
     task = ChatCompletion(
||||
@@ -2,8 +2,9 @@ import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.tasks import StartWarmup
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerWaitingForModel,
|
||||
RunnerLoading,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
@@ -21,9 +22,9 @@ from exo.worker.tests.unittests.conftest import (
|
||||
)
|
||||
|
||||
|
||||
def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
|
||||
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
|
||||
"""
|
||||
For non-zero device_rank shards, StartWarmup should be emitted when all
|
||||
For non-final device_rank shards, StartWarmup should be emitted when all
|
||||
shards in the instance are Loaded/WarmingUp.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
@@ -36,13 +37,13 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
|
||||
)
|
||||
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
@@ -50,10 +51,10 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
global_download_status={NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
@@ -128,7 +129,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerWaitingForModel(),
|
||||
RUNNER_1_ID: RunnerIdle(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
}
|
||||
|
||||
@@ -149,6 +150,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
"""
|
||||
Rank-zero shard should not start warmup until all non-zero ranks are
|
||||
already WarmingUp.
|
||||
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
|
||||
emitted when all shards in the instance are Loaded/WarmingUp.
|
||||
In a 2-node setup, rank 0 is the accepting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
@@ -159,6 +163,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 0 is the accepting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
@@ -173,6 +178,93 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert isinstance(result, StartWarmup)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
|
||||
|
||||
def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
|
||||
"""
|
||||
For connecting rank (device_rank == world_size - 1), StartWarmup should
|
||||
only be emitted once all the other runners are already warming up.
|
||||
In a 2-node setup, rank 1 is the connecting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 1 is the connecting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerWarmingUp(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert isinstance(result, StartWarmup)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
|
||||
|
||||
def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
|
||||
"""
|
||||
Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
|
||||
In a 2-node setup, rank 0 is the accepting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 0 is the accepting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerLoading(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
@@ -184,3 +276,46 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
"""
|
||||
Connecting rank (device_rank == world_size - 1) should not start warmup
|
||||
until all other ranks are already WarmingUp.
|
||||
In a 2-node setup, rank 1 is the connecting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 1 is the connecting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||

@@ -0,0 +1,208 @@
+# Check tasks are complete before runner is ever ready.
+from collections.abc import Iterable
+from typing import Callable
+
+import pytest
+
+import exo.worker.runner.runner as mlx_runner
+from exo.shared.types.api import ChatCompletionMessage
+from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.events import (
+    ChunkGenerated,
+    Event,
+    RunnerStatusUpdated,
+    TaskAcknowledged,
+    TaskStatusUpdated,
+)
+from exo.shared.types.tasks import (
+    ChatCompletion,
+    ChatCompletionTaskParams,
+    ConnectToGroup,
+    LoadModel,
+    Shutdown,
+    StartWarmup,
+    Task,
+    TaskStatus,
+)
+from exo.shared.types.worker.runner_response import GenerationResponse
+from exo.shared.types.worker.runners import (
+    RunnerConnected,
+    RunnerConnecting,
+    RunnerIdle,
+    RunnerLoaded,
+    RunnerLoading,
+    RunnerReady,
+    RunnerRunning,
+    RunnerShutdown,
+    RunnerWarmingUp,
+)
+from exo.utils.channels import mp_channel
+
+from ...constants import (
+    CHAT_COMPLETION_TASK_ID,
+    COMMAND_1_ID,
+    INITIALIZATION_TASK_ID,
+    INSTANCE_1_ID,
+    LOAD_TASK_ID,
+    MODEL_A_ID,
+    NODE_A,
+    RUNNER_1_ID,
+    SHUTDOWN_TASK_ID,
+    WARMUP_TASK_ID,
+)
+from ..conftest import get_bound_mlx_ring_instance
+
+
+def make_nothin[T, U, V](res: T) -> Callable[[], T]:
+    def nothin(*_1: U, **_2: V) -> T:
+        return res
+
+    return nothin
+
+
+nothin = make_nothin(None)
+
+
+INIT_TASK = ConnectToGroup(
+    task_id=INITIALIZATION_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+LOAD_TASK = LoadModel(
+    task_id=LOAD_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+WARMUP_TASK = StartWarmup(
+    task_id=WARMUP_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+)
+
+SHUTDOWN_TASK = Shutdown(
+    task_id=SHUTDOWN_TASK_ID,
+    instance_id=INSTANCE_1_ID,
+    runner_id=RUNNER_1_ID,
+)
+
+CHAT_PARAMS = ChatCompletionTaskParams(
+    model=str(MODEL_A_ID),
+    messages=[ChatCompletionMessage(role="user", content="hello")],
+    stream=True,
+    max_tokens=4,
+    temperature=0.0,
+)
+
+CHAT_TASK = ChatCompletion(
+    task_id=CHAT_COMPLETION_TASK_ID,
+    command_id=COMMAND_1_ID,
+    task_params=CHAT_PARAMS,
+    instance_id=INSTANCE_1_ID,
+)
+
+
+def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
+    for test_event, true_event in zip(test_events, true_events, strict=True):
+        test_event.event_id = true_event.event_id
+        assert test_event == true_event, f"{test_event} != {true_event}"
+
+
+@pytest.fixture
+def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
+    # initialize_mlx returns a "group" equal to 1
+    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
+    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
+    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
+    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
+
+    def fake_generate(*_1: object, **_2: object):
+        yield GenerationResponse(token=0, text="hi", finish_reason="stop")
+
+    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
+
+
+def _run(tasks: Iterable[Task]):
+    bound_instance = get_bound_mlx_ring_instance(
+        instance_id=INSTANCE_1_ID,
+        model_id=MODEL_A_ID,
+        runner_id=RUNNER_1_ID,
+        node_id=NODE_A,
+    )
+
+    task_sender, task_receiver = mp_channel[Task]()
+    event_sender, event_receiver = mp_channel[Event]()
+
+    with task_sender, event_receiver:
+        for t in tasks:
+            task_sender.send(t)
+
+        # worst monkeypatch known to man
+        # this is some c++ nonsense
+        event_sender.close = nothin
+        event_sender.join = nothin
+        task_receiver.close = nothin
+        task_receiver.join = nothin
+
+        mlx_runner.main(bound_instance, event_sender, task_receiver)
+
+    return event_receiver.collect()
+
+
+def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
+    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
+
+    expected_chunk = ChunkGenerated(
+        command_id=COMMAND_1_ID,
+        chunk=TokenChunk(
+            idx=0,
+            model=MODEL_A_ID,
+            text="hi",
+            token_id=0,
+            finish_reason="stop",
+        ),
+    )
+
+    assert_events_equal(
+        events,
+        [
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
+            TaskStatusUpdated(
+                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
+            ),
+            TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
+            RunnerStatusUpdated(
+                runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
+            ),
+            TaskStatusUpdated(
+                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
+            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=LOAD_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
+            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
+            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=WARMUP_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
+            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
+            TaskStatusUpdated(
+                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
+            ),
+            TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
+            expected_chunk,
+            TaskStatusUpdated(
+                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
+            TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
+            TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
+            TaskStatusUpdated(
+                task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
+            ),
+            # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
+            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
+        ],
+    )

@@ -0,0 +1 @@
+# TODO:

@@ -1,33 +1,64 @@
-import socket
+import http.client
 
 from anyio import create_task_group, to_thread
+from loguru import logger
 
 from exo.shared.topology import Topology
 from exo.shared.types.common import NodeId
 
 
 # TODO: ref. api port
 async def check_reachability(
-    target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
+    target_ip: str,
+    expected_node_id: NodeId,
+    self_node_id: NodeId,
+    out: dict[NodeId, set[str]],
 ) -> None:
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.settimeout(1)  # 1 second timeout
-    try:
-        result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
-    except socket.gaierror:
-        # seems to throw on ipv6 loopback. oh well
-        # logger.warning(f"invalid {target_ip=}")
-        return
-    finally:
-        sock.close()
+    """Check if a node is reachable at the given IP and verify its identity."""
+
+    def _fetch_remote_node_id() -> NodeId | None:
+        connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
+        try:
+            connection.request("GET", "/node_id")
+            response = connection.getresponse()
+            if response.status != 200:
+                return None
+
+            body = response.read().decode("utf-8").strip()
+
+            # Strip quotes if present (JSON string response)
+            if body.startswith('"') and body.endswith('"') and len(body) >= 2:
+                body = body[1:-1]
+
+            return NodeId(body) or None
+        except OSError:
+            return None
+        finally:
+            connection.close()
 
-    if result == 0:
-        if target_node_id not in out:
-            out[target_node_id] = set()
-        out[target_node_id].add(target_ip)
+    remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
+    if remote_node_id is None:
+        return
+
+    if remote_node_id == self_node_id:
+        return
+
+    if remote_node_id != expected_node_id:
+        logger.warning(
+            f"Discovered node with unexpected node_id; "
+            f"ip={target_ip}, expected_node_id={expected_node_id}, "
+            f"remote_node_id={remote_node_id}"
+        )
+        return
+
+    if remote_node_id not in out:
+        out[remote_node_id] = set()
+    out[remote_node_id].add(target_ip)
 
 
-async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
+async def check_reachable(
+    topology: Topology, self_node_id: NodeId
+) -> dict[NodeId, set[str]]:
+    """Check which nodes are reachable and return their IPs."""
     reachable: dict[NodeId, set[str]] = {}
     async with create_task_group() as tg:
         for node in topology.list_nodes():
@@ -35,7 +66,11 @@ async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
                 continue
             for iface in node.node_profile.network_interfaces:
                 tg.start_soon(
-                    check_reachability, iface.ip_address, node.node_id, reachable
+                    check_reachability,
+                    iface.ip_address,
+                    node.node_id,
+                    self_node_id,
+                    reachable,
                )
 
     return reachable
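Note: because reachability is now verified over HTTP rather than a bare TCP connect, the probe can be reproduced by hand against any running node. A minimal sketch mirroring the diff's approach (the peer IP below is a placeholder; 52415 is the API port used above):

```python
import http.client

# A healthy peer answers GET /node_id with its node id as a JSON string.
conn = http.client.HTTPConnection("192.168.1.42", 52415, timeout=1)
conn.request("GET", "/node_id")
print(conn.getresponse().read().decode("utf-8"))
conn.close()
```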

uv.lock (generated, 41 changed lines)
@@ -334,8 +334,10 @@ dependencies = [
     { name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
     { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+    { name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -374,9 +376,11 @@ requires-dist = [
     { name = "huggingface-hub", specifier = ">=0.33.4" },
     { name = "hypercorn", specifier = ">=0.18.0" },
     { name = "loguru", specifier = ">=0.7.3" },
-    { name = "mlx", specifier = ">=0.30.1" },
+    { name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.30.1" },
+    { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.30.1" },
     { name = "mlx-lm", specifier = ">=0.28.3" },
     { name = "networkx", specifier = ">=3.5" },
+    { name = "openai-harmony", specifier = ">=0.0.8" },
     { name = "protobuf", specifier = ">=6.32.0" },
     { name = "psutil", specifier = ">=7.0.0" },
     { name = "pydantic", specifier = ">=2.11.7" },
@@ -801,6 +805,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
 ]
 
+[package.optional-dependencies]
+cpu = [
+    { name = "mlx-cpu", marker = "sys_platform == 'linux'" },
+]
+
+[[package]]
+name = "mlx-cpu"
+version = "0.30.1"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/64/51/32903727a68a61e972383e28a775c1f5e5f0628552c85cbc6103d68c0dc4/mlx_cpu-0.30.1-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:3f5dc2e4d0849181f8253508bb6a0854250483fc63d43ac79ec614b19824b172", size = 8992394, upload-time = "2025-12-18T00:16:13.696Z" },
+    { url = "https://files.pythonhosted.org/packages/0c/74/69c21bb907f3c4064881ab0653029c939ae15fc4e63a5301ef8643cb1d68/mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:c9ea6992d8c001e1123dfd3b4d4405ff576c787eec52656ad405e3d033a8be60", size = 10553055, upload-time = "2025-12-18T00:16:16.104Z" },
+]
+
 [[package]]
 name = "mlx-lm"
 version = "0.28.3"
@@ -946,6 +964,27 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" },
 ]
 
+[[package]]
+name = "openai-harmony"
+version = "0.0.8"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3e/92/2d038d096f29179c7c9571b431f9e739f87a487121901725e23fe338dd9d/openai_harmony-0.0.8.tar.gz", hash = "sha256:6e43f98e6c242fa2de6f8ea12eab24af63fa2ed3e89c06341fb9d92632c5cbdf", size = 284777, upload-time = "2025-11-05T19:07:06.727Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/45/c6/2502f416d46be3ec08bb66d696cccffb57781a499e3ff2e4d7c174af4e8f/openai_harmony-0.0.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:029ec25ca74abe48fdb58eb9fdd2a8c1618581fc33ce8e5653f8a1ffbfbd9326", size = 2627806, upload-time = "2025-11-05T19:06:57.063Z" },
+    { url = "https://files.pythonhosted.org/packages/d3/d2/ce6953ca87db9cae3e775024184da7d1c5cb88cead19a2d75b42f00a959c/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4f709815924ec325b9a890e6ab2bbb0ceec8e319a4e257328eb752cf36b2efc", size = 2948463, upload-time = "2025-11-05T19:06:48.17Z" },
+    { url = "https://files.pythonhosted.org/packages/fa/4c/b553c9651662d6ce102ca7f3629d268b23df1abe5841e24bed81e8a8e949/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cfcfd963b50a41fc656c84d3440ca6eecdccd6c552158ce790b8f2e33dfb5a9", size = 2704083, upload-time = "2025-11-05T19:06:50.205Z" },
+    { url = "https://files.pythonhosted.org/packages/9b/af/4eec8f9ab9c27bcdb444460c72cf43011d176fc44c79d6e113094ca1e152/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a3a16972aa1cee38ea958470cd04ac9a2d5ac38fdcf77ab686611246220c158", size = 2959765, upload-time = "2025-11-05T19:06:53.62Z" },
+    { url = "https://files.pythonhosted.org/packages/11/3c/33f3374e4624e0e776f6b13b73c45a7ead7f9c4529f8369ed5bfcaa30cac/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4d5cfa168e74d08f8ba6d58a7e49bc7daef4d58951ec69b66b0d56f4927a68d", size = 3427031, upload-time = "2025-11-05T19:06:51.829Z" },
+    { url = "https://files.pythonhosted.org/packages/25/3f/1a192b93bb47c6b44cd98ba8cc1d3d2a9308f1bb700c3017e6352da11bda/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c007d277218a50db8839e599ed78e0fffe5130f614c3f6d93ae257f282071a29", size = 2953260, upload-time = "2025-11-05T19:06:55.406Z" },
+    { url = "https://files.pythonhosted.org/packages/5b/f8/93b582cad3531797c3db7c2db5400fd841538ccddfd9f5e3df61be99a630/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8565d4f5a0638da1bffde29832ed63c9e695c558611053add3b2dc0b56c92dbc", size = 3127044, upload-time = "2025-11-05T19:06:59.553Z" },
+    { url = "https://files.pythonhosted.org/packages/1d/10/4327dbf87f75ae813405fd9a9b4a5cde63d506ffed0a096a440a4cabd89c/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cbaa3bda75ef0d8836e1f8cc84af62f971b1d756d740efc95c38c3e04c0bfde2", size = 2932931, upload-time = "2025-11-05T19:07:01.437Z" },
+    { url = "https://files.pythonhosted.org/packages/8a/c8/1774eec4f6f360ef57618fb8f52e3d3af245b2491bd0297513aa09eec04b/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:772922a9bd24e133950fad71eb1550836f415a88e8c77870e12d0c3bd688ddc2", size = 2996140, upload-time = "2025-11-05T19:07:03.438Z" },
+    { url = "https://files.pythonhosted.org/packages/60/c3/3d1e01e2dba517a91760e4a03e4f20ffc75039a6fe584d0e6f9b5c78fd15/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:007b0476a1f331f8130783f901f1da6f5a7057af1a4891f1b6a31dec364189b5", size = 3205080, upload-time = "2025-11-05T19:07:05.078Z" },
+]
+
 [[package]]
 name = "packaging"
 version = "25.0"