Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
19a21e9065 feat(dashboard): add light/dark mode toggle with warm parchment palette
Adds a theme system to the EXO dashboard with a "Mission Control, Dawn
Shift" light mode — warm parchment backgrounds (oklch(0.97 0.015 80))
and deep amber/brass accents (oklch(0.50 0.14 65)) that feel premium
rather than cold.

Changes:
- dashboard/src/lib/stores/theme.svelte.ts: new Svelte 5 rune store,
  persists choice to localStorage under 'exo-theme'
- dashboard/src/app.html: FOUC prevention — html starts as class="dark",
  inline script reads localStorage and switches to class="light" before
  first paint
- dashboard/src/routes/+layout.svelte: calls theme.init() on mount to
  sync rune state with the DOM class
- dashboard/src/lib/components/HeaderNav.svelte: sun/moon toggle button
  in the right nav area
- dashboard/src/app.css: full html.light palette + utility overrides
  (scrollbar, logo filter, graph links, scanlines, etc.)

No new npm dependencies — avoids mode-watcher entirely.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-18 12:36:51 -08:00
15 changed files with 369 additions and 51 deletions

View File

@@ -16,9 +16,10 @@
/* Gotham-inspired accent colors */
--exo-grid: oklch(0.25 0 0);
--exo-scanline: oklch(0.15 0 0);
--exo-glow-yellow: 0 0 20px oklch(0.85 0.18 85 / 0.3);
--exo-glow-yellow-strong: 0 0 40px oklch(0.85 0.18 85 / 0.5);
--exo-glow-yellow: oklch(0.85 0.18 85 / 0.3);
--exo-glow-yellow-strong: oklch(0.85 0.18 85 / 0.5);
--exo-bg-hover: oklch(0.18 0 0);
/* Theme Variables */
--radius: 0.375rem;
--background: var(--exo-black);
@@ -41,6 +42,237 @@
--ring: var(--exo-yellow);
}
/* ============================================================
LIGHT THEME — "Mission Control, Dawn Shift"
Warm parchment + deep amber. Applied when <html> has .light class.
============================================================ */
html.light {
/* EXO brand palette — warm amber shift */
--exo-black: oklch(0.97 0.015 80);
--exo-dark-gray: oklch(0.92 0.012 80);
--exo-medium-gray: oklch(0.83 0.009 78);
--exo-light-gray: oklch(0.50 0.018 75);
--exo-yellow: oklch(0.50 0.14 65);
--exo-yellow-darker: oklch(0.40 0.13 65);
--exo-yellow-glow: oklch(0.60 0.14 65);
--exo-grid: oklch(0.88 0.009 80);
--exo-scanline: oklch(0.93 0.010 80);
--exo-glow-yellow: oklch(0.50 0.14 65 / 0.12);
--exo-glow-yellow-strong: oklch(0.50 0.14 65 / 0.22);
--exo-bg-hover: oklch(0.89 0.010 80);
/* Semantic tokens */
--background: oklch(0.97 0.015 80);
--foreground: oklch(0.13 0.015 75);
--card: oklch(0.92 0.012 80);
--card-foreground: oklch(0.13 0.015 75);
--popover: oklch(0.95 0.012 80);
--popover-foreground: oklch(0.13 0.015 75);
--primary: oklch(0.50 0.14 65);
--primary-foreground: oklch(0.97 0.015 80);
--secondary: oklch(0.88 0.008 80);
--secondary-foreground: oklch(0.15 0.012 75);
--muted: oklch(0.90 0.009 80);
--muted-foreground: oklch(0.50 0.018 75);
--accent: oklch(0.88 0.008 80);
--accent-foreground: oklch(0.15 0.012 75);
--destructive: oklch(0.52 0.22 25);
--border: oklch(0.84 0.007 78);
--input: oklch(0.87 0.008 80);
--ring: oklch(0.50 0.14 65);
}
/* ============================================================
LIGHT MODE UTILITY OVERRIDES
============================================================ */
html.light {
& .text-white,
& .text-white\/90,
& .text-white\/80,
& .text-white\/70 {
color: var(--foreground) !important;
}
& .text-white\/60,
& .text-white\/50 {
color: color-mix(in oklch, var(--foreground) 60%, transparent) !important;
}
& .text-white\/40,
& .text-white\/30 {
color: color-mix(in oklch, var(--foreground) 38%, transparent) !important;
}
& .bg-black\/80,
& .bg-black\/60,
& .bg-black\/50,
& .bg-black\/40 {
background-color: oklch(0.90 0.010 80 / 0.7) !important;
}
& [class*="bg-exo-black/"] {
background-color: oklch(0.90 0.010 80 / 0.6) !important;
}
& [class*="shadow-black"] {
--tw-shadow-color: oklch(0.30 0.010 75 / 0.10) !important;
}
& ::-webkit-scrollbar-track {
background: oklch(0.93 0.010 80) !important;
}
& ::-webkit-scrollbar-thumb {
background: oklch(0.76 0.010 78) !important;
}
& ::-webkit-scrollbar-thumb:hover {
background: oklch(0.50 0.14 65 / 0.6) !important;
}
& .command-panel {
background: linear-gradient(
180deg,
oklch(0.94 0.012 80 / 0.96) 0%,
oklch(0.91 0.010 80 / 0.98) 100%
) !important;
border-color: oklch(0.82 0.008 78) !important;
box-shadow:
inset 0 1px 0 oklch(1 0 0 / 0.6),
0 4px 20px oklch(0.30 0.010 75 / 0.08) !important;
}
& .glow-text {
text-shadow:
0 0 12px oklch(0.50 0.14 65 / 0.20),
0 1px 3px oklch(0.30 0.010 75 / 0.12) !important;
}
& .grid-bg {
background-image:
linear-gradient(oklch(0.75 0.008 78 / 0.25) 1px, transparent 1px),
linear-gradient(90deg, oklch(0.75 0.008 78 / 0.25) 1px, transparent 1px) !important;
}
& .scanlines::before {
background: repeating-linear-gradient(
0deg,
transparent,
transparent 2px,
oklch(0.50 0.010 78 / 0.018) 2px,
oklch(0.50 0.010 78 / 0.018) 4px
) !important;
}
& .crt-screen {
background: radial-gradient(
ellipse at center,
oklch(0.95 0.012 80) 0%,
oklch(0.92 0.010 80) 50%,
oklch(0.89 0.009 80) 100%
) !important;
box-shadow:
inset 0 0 60px oklch(0.30 0.010 75 / 0.04),
0 0 30px oklch(0.50 0.14 65 / 0.04) !important;
}
& .graph-link {
stroke: oklch(0.50 0.018 75 / 0.45) !important;
filter: none !important;
}
& .graph-link-active {
stroke: oklch(0.50 0.14 65 / 0.75) !important;
filter: none !important;
}
& .shooting-stars {
display: none !important;
}
& img[alt="EXO"] {
filter: brightness(0) drop-shadow(0 0 6px oklch(0.30 0.010 75 / 0.10)) !important;
}
& .text-red-400 { color: oklch(0.52 0.22 25) !important; }
& .text-green-400 { color: oklch(0.48 0.17 155) !important; }
& .text-blue-200,
& .text-blue-300,
& .text-blue-400 { color: oklch(0.48 0.17 250) !important; }
& .bg-red-500\/10 { background-color: oklch(0.52 0.22 25 / 0.07) !important; }
& .bg-red-500\/20 { background-color: oklch(0.52 0.22 25 / 0.11) !important; }
& .bg-red-500\/30 { background-color: oklch(0.52 0.22 25 / 0.14) !important; }
& textarea,
& input[type="text"] { color: var(--foreground) !important; }
& textarea::placeholder,
& input::placeholder { color: oklch(0.50 0.012 78 / 0.55) !important; }
& .code-block-wrapper,
& .math-display-wrapper {
background: oklch(0.95 0.010 80) !important;
border-color: oklch(0.83 0.007 78) !important;
}
& .code-block-header,
& .math-display-header {
background: oklch(0.91 0.009 80) !important;
border-color: oklch(0.85 0.007 78) !important;
}
& .inline-code {
background: oklch(0.89 0.009 80) !important;
color: oklch(0.20 0.012 75) !important;
}
& blockquote { background: oklch(0.93 0.010 80) !important; }
& th {
background: oklch(0.90 0.009 80) !important;
border-color: oklch(0.80 0.007 78) !important;
}
& td { border-color: oklch(0.84 0.007 78) !important; }
& hr { border-color: oklch(0.84 0.007 78) !important; }
& .hljs { color: oklch(0.22 0.012 75) !important; }
& .hljs-keyword, & .hljs-selector-tag, & .hljs-literal, & .hljs-section, & .hljs-link {
color: oklch(0.45 0.18 300) !important;
}
& .hljs-string, & .hljs-title, & .hljs-name, & .hljs-type,
& .hljs-attribute, & .hljs-symbol, & .hljs-bullet, & .hljs-addition,
& .hljs-variable, & .hljs-template-tag, & .hljs-template-variable {
color: oklch(0.45 0.14 65) !important;
}
& .hljs-comment, & .hljs-quote, & .hljs-deletion, & .hljs-meta {
color: oklch(0.55 0.010 78) !important;
}
& .hljs-number, & .hljs-regexp, & .hljs-built_in {
color: oklch(0.45 0.15 160) !important;
}
& .hljs-function, & .hljs-class .hljs-title {
color: oklch(0.42 0.17 240) !important;
}
& .katex, & .katex .mord, & .katex .minner, & .katex .mop,
& .katex .mbin, & .katex .mrel, & .katex .mpunct {
color: oklch(0.15 0.012 75) !important;
}
& .katex .frac-line, & .katex .overline-line, & .katex .underline-line,
& .katex .hline, & .katex .rule {
border-color: oklch(0.25 0.012 75) !important;
background: oklch(0.25 0.012 75) !important;
}
& .katex svg { fill: oklch(0.25 0.012 75) !important; stroke: oklch(0.25 0.012 75) !important; }
& .katex svg path { stroke: oklch(0.25 0.012 75) !important; }
& .katex .mopen, & .katex .mclose,
& .katex .delimsizing, & [class^="katex .delim-size"] {
color: oklch(0.35 0.012 75) !important;
}
& .latex-proof { background: oklch(0.96 0.010 80) !important; border-left-color: oklch(0.72 0.010 78) !important; }
& .latex-proof-header { color: oklch(0.22 0.012 75) !important; }
& .latex-proof-content { color: oklch(0.15 0.012 75) !important; }
& .latex-proof-content::after { color: oklch(0.48 0.012 75) !important; }
& .latex-theorem { background: oklch(0.94 0.010 80) !important; border-color: oklch(0.80 0.008 78) !important; }
& .latex-diagram-placeholder {
background: oklch(0.96 0.010 80) !important;
border-color: oklch(0.80 0.008 78) !important;
color: oklch(0.38 0.012 75) !important;
}
}
@theme inline {
--radius-sm: calc(var(--radius) - 2px);
--radius-md: var(--radius);

View File

@@ -1,7 +1,15 @@
<!doctype html>
<html lang="en">
<html lang="en" class="dark">
<head>
<meta charset="utf-8" />
<script>
try {
if (localStorage.getItem('exo-theme') === 'light') {
document.documentElement.classList.remove('dark');
document.documentElement.classList.add('light');
}
} catch (_) {}
</script>
<link rel="icon" href="%sveltekit.assets%/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>EXO</title>

View File

@@ -1,5 +1,6 @@
<script lang="ts">
import { browser } from "$app/environment";
import { theme } from "$lib/stores/theme.svelte";
export let showHome = true;
export let onHome: (() => void) | null = null;
@@ -79,10 +80,48 @@
/>
</button>
<!-- Right: Home + Downloads -->
<!-- Right: Theme toggle + Home + Downloads -->
<div
class="absolute right-6 top-1/2 -translate-y-1/2 flex items-center gap-4"
>
<button
onclick={() => theme.toggle()}
class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
title={theme.isLight ? "Switch to dark mode" : "Switch to light mode"}
aria-label={theme.isLight
? "Switch to dark mode"
: "Switch to light mode"}
>
{#if theme.isLight}
<svg
class="w-4 h-4 text-exo-light-gray"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M21 12.79A9 9 0 1111.21 3a7 7 0 009.79 9.79z"
/>
</svg>
{:else}
<svg
class="w-4 h-4 text-exo-light-gray"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<circle cx="12" cy="12" r="5" />
<path
stroke-linecap="round"
d="M12 1v2m0 18v2M4.22 4.22l1.42 1.42m12.72 12.72l1.42 1.42M1 12h2m18 0h2M4.22 19.78l1.42-1.42M18.36 5.64l1.42-1.42"
/>
</svg>
{/if}
</button>
{#if showHome}
<button
onclick={handleHome}

View File

@@ -0,0 +1,28 @@
import { browser } from "$app/environment";
let _isLight = $state(false);
export const theme = {
get isLight() {
return _isLight;
},
init() {
if (!browser) return;
_isLight = document.documentElement.classList.contains("light");
},
toggle() {
if (!browser) return;
_isLight = !_isLight;
if (_isLight) {
document.documentElement.classList.remove("dark");
document.documentElement.classList.add("light");
localStorage.setItem("exo-theme", "light");
} else {
document.documentElement.classList.remove("light");
document.documentElement.classList.add("dark");
localStorage.setItem("exo-theme", "dark");
}
},
};

View File

@@ -1,7 +1,13 @@
<script lang="ts">
import "../app.css";
import { onMount } from "svelte";
import { theme } from "$lib/stores/theme.svelte";
let { children } = $props();
onMount(() => {
theme.init();
});
</script>
<svelte:head>

View File

@@ -1,6 +1,7 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -21,7 +22,7 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -45,8 +46,8 @@ class DownloadCoordinator:
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -294,16 +295,15 @@ class DownloadCoordinator:
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self._system_id,
origin=self.node_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)

View File

@@ -1,10 +1,11 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
from typing import Iterator, Self
import anyio
from anyio.abc import TaskGroup
@@ -37,11 +38,12 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> Self:
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -55,6 +57,9 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
@@ -63,6 +68,7 @@ class Node:
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
@@ -89,6 +95,7 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -126,6 +133,7 @@ class Node:
master,
api,
node_id,
event_index_counter,
args.offline,
)
@@ -205,6 +213,7 @@ class Node:
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
@@ -215,6 +224,7 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
@@ -232,6 +242,7 @@ class Node:
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -131,7 +131,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -183,7 +183,6 @@ class API:
) -> None:
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
@@ -234,7 +233,6 @@ class API:
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
@@ -548,7 +546,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -893,7 +891,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -979,7 +977,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1410,7 +1408,7 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
for idx, event in self.event_buffer.drain_indexed():
@@ -1474,12 +1472,12 @@ class API:
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self._system_id, command=command)
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(

View File

@@ -29,7 +29,7 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -90,8 +90,7 @@ class Master:
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -289,7 +288,7 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self._system_id, command=cmd
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
@@ -416,7 +415,7 @@ class Master:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=self._system_id,
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
@@ -429,7 +428,7 @@ class Master:
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
origin=self._system_id,
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,
event=event.event,

View File

@@ -15,7 +15,7 @@ from exo.shared.types.commands import (
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -75,12 +75,13 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=SystemId("Worker"),
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
@@ -107,7 +108,7 @@ async def test_master():
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=SystemId("API"),
origin=node_id,
command=(
PlaceInstance(
command_id=CommandId(),
@@ -132,7 +133,7 @@ async def test_master():
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=SystemId("API"),
origin=node_id,
command=(
TextGeneration(
command_id=CommandId(),

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import channel
# ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
)
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: SystemId
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: SystemId
origin: NodeId
command: DownloadCommand

View File

@@ -25,10 +25,6 @@ class NodeId(Id):
pass
class SystemId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -166,6 +166,6 @@ class ForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: SystemId
origin: NodeId
session: SessionId
event: Event

View File

@@ -1,6 +1,7 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
@@ -16,7 +17,7 @@ from exo.shared.types.commands import (
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
@@ -63,12 +64,14 @@ class Worker:
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
@@ -83,8 +86,6 @@ class Worker:
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
@@ -131,7 +132,7 @@ class Worker:
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
event_id = f_event.event.event_id
@@ -211,7 +212,7 @@ class Worker:
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self._system_id,
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
@@ -311,7 +312,7 @@ class Worker:
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
origin=self.node_id,
command=RequestEventLog(since_idx=since_idx),
)
)
@@ -338,16 +339,15 @@ class Worker:
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self._system_id,
origin=self.node_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe