Mirror of https://github.com/exo-explore/exo.git, synced 2025-12-23 22:27:50 -05:00.
Commit: full mlx caching implementation.
@@ -32,6 +32,7 @@ dependencies = [
|
||||
"transformers>=4.55.2",
|
||||
"cobs>=1.2.2",
|
||||
"loguru>=0.7.3",
|
||||
"textual>=5.3.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
14
run.sh
14
run.sh
@@ -33,16 +33,16 @@ fi
|
||||
# Configure MLX
|
||||
# ./configure_mlx.sh
|
||||
|
||||
# First command (worker) - changes based on replica flag
|
||||
if [ "$REPLICA" = true ]; then
|
||||
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export EXO_HOME=.exo; uv run exo-worker'\""
|
||||
else
|
||||
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c uv run exo-worker\""
|
||||
fi
|
||||
|
||||
# Second command (master) - changes based on replica flag
|
||||
if [ "$REPLICA" = true ]; then
|
||||
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export RUST_LOG=true EXO_RUN_AS_REPLICA=1 EXO_HOME=.exo API_PORT=8001; uv run exo-master'\""
|
||||
else
|
||||
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export RUST_LOG=true; uv run exo-master'\""
|
||||
fi
|
||||
|
||||
# First command (worker) - changes based on replica flag
|
||||
if [ "$REPLICA" = true ]; then
|
||||
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c bash -c 'export EXO_HOME=.exo; uv run exo-worker'\""
|
||||
else
|
||||
osascript -e "tell app \"Terminal\" to do script \"cd '$DIR'; nix develop -c uv run exo-worker\""
|
||||
fi
|
||||
@@ -1,26 +1,36 @@
|
||||
# pyright: reportAny=false
|
||||
|
||||
import asyncio
|
||||
import curses
|
||||
import time
|
||||
import json
|
||||
import argparse
|
||||
import sys
|
||||
from logging import Logger
|
||||
import time
|
||||
from dataclasses import is_dataclass, asdict
|
||||
from logging import getLogger
|
||||
from typing import List, Optional, Any, Sequence, Tuple
|
||||
|
||||
# Your existing imports — unchanged
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig
|
||||
from exo.shared.types.events.components import EventFromEventLog
|
||||
from exo.shared.types.events import Event
|
||||
|
||||
# Globals
|
||||
logger: Logger = Logger('helper_log')
|
||||
event_log_manager: Optional[EventLogManager] = None
|
||||
worker_mode: bool = False
|
||||
# --- Third-party UI (new) ---
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
from rich.panel import Panel
|
||||
from rich.console import RenderableType
|
||||
|
||||
# Worker-related event types
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.containers import Horizontal, Vertical
|
||||
from textual.widgets import Static, ListView, ListItem, Input, Footer, Label
|
||||
from textual.reactive import reactive
|
||||
from textual import on
|
||||
from textual.binding import Binding
|
||||
from textual.message import Message
|
||||
|
||||
logger = getLogger("helper_log")
|
||||
|
||||
# Worker-related event types (same set)
|
||||
WORKER_EVENT_TYPES = {
|
||||
'TaskCreated', 'TaskStateUpdated', 'TaskFailed', 'TaskDeleted',
|
||||
'ChunkGenerated',
|
||||
@@ -29,17 +39,19 @@ WORKER_EVENT_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
# ---------- Data / DB helpers (mostly your original logic) ----------
|
||||
|
||||
event_log_manager: Optional[EventLogManager] = None
|
||||
|
||||
async def init_db() -> None:
|
||||
global event_log_manager
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
event_log_manager = EventLogManager(EventLogConfig())
|
||||
await event_log_manager.initialize()
|
||||
|
||||
|
||||
async def get_events_since(since: int) -> "Sequence[EventFromEventLog[Event]]":
    """Fetch every event in the global log with index >= *since*.

    init_db() must have completed first; the module-level manager is
    required to be set by then.
    """
    manager = event_log_manager
    assert manager is not None
    # type: ignore[attr-defined, return-value]
    return await manager.global_events.get_events_since(since)
|
||||
|
||||
|
||||
async def load_all_events() -> List[EventFromEventLog[Event]]:
|
||||
events: List[EventFromEventLog[Event]] = []
|
||||
since = 0
|
||||
@@ -51,7 +63,6 @@ async def load_all_events() -> List[EventFromEventLog[Event]]:
|
||||
since += len(new_events)
|
||||
return events
|
||||
|
||||
|
||||
def compute_states(events: List[EventFromEventLog[Event]]) -> List[State]:
|
||||
states: List[State] = [State()]
|
||||
state = states[0]
|
||||
@@ -60,34 +71,95 @@ def compute_states(events: List[EventFromEventLog[Event]]) -> List[State]:
|
||||
states.append(state)
|
||||
return states
|
||||
|
||||
def filter_worker_state(state: "State") -> dict:
    """Project a full State down to the sections a worker view needs.

    The state is round-tripped through its JSON serialization so the
    result is a plain dict of JSON-compatible values.
    """
    full = json.loads(state.model_dump_json())
    picked = {
        key: full.get(key, {})
        for key in ('node_status', 'instances', 'runners', 'tasks')
    }
    picked['last_event_applied_idx'] = full.get('last_event_applied_idx', 0)
    return picked
|
||||
|
||||
def print_event(event: "EventFromEventLog[Event]") -> None:
    """Print a one-line human-readable summary of a logged event to stdout."""
    raw_name = type(event.event).__name__
    # 'Some_Event' -> 'Some Event' for readability.
    pretty_name = raw_name.replace('_', ' ').title()
    parts = [f"{name}={val!r}" for name, val in vars(event.event).items()]
    print(f"[{event.idx_in_log}] {pretty_name}: {', '.join(parts)}")
|
||||
def event_type_name(e: "EventFromEventLog[Event]") -> str:
    """Return the class name of the wrapped event payload."""
    payload_cls = type(e.event)
    return payload_cls.__name__
|
||||
|
||||
def is_worker_event(e: "EventFromEventLog[Event]") -> bool:
    """True when the wrapped event's type is one the worker view cares about."""
    return type(e.event).__name__ in WORKER_EVENT_TYPES
|
||||
|
||||
def safe_json(obj: Any) -> str:
    """Serialize unknown objects to JSON-ish string safely."""
    def _coerce(value: Any):
        # Dataclass instances become plain dicts; anything that blows up
        # here (e.g. a dataclass *class* rather than instance) just falls
        # through to the generic handling below.
        try:
            if is_dataclass(value):
                return asdict(value)
        except Exception:
            pass
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, dict):
            return {str(key): _coerce(item) for key, item in value.items()}
        if isinstance(value, (list, tuple, set)):
            return [_coerce(item) for item in value]
        # Last chance: keep the value verbatim when json can encode it,
        # otherwise fall back to its repr().
        try:
            json.dumps(value)  # type: ignore
        except Exception:
            return repr(value)
        return value

    try:
        return json.dumps(_coerce(obj), indent=2, ensure_ascii=False)
    except Exception:
        # Last resort
        return repr(obj)
|
||||
|
||||
def summarize_event_line(e: "EventFromEventLog[Event]", max_len: int = 160) -> Text:
    """Build one styled line summarizing an event for the list view.

    Format: ``[idx] EventType: key=value, key=value`` with per-type value
    colors, truncated with an ellipsis past *max_len* characters.
    """
    etype = event_type_name(e)
    attrs = vars(e.event)
    prefix = Text(f"[{e.idx_in_log}] ", style="bold dim")
    t = Text(etype, style="bold cyan")
    t = prefix + t + Text(": ", style="dim")
    first = True
    for k, v in attrs.items():
        if not first:
            t.append(", ", style="dim")
        first = False
        t.append(str(k), style="magenta")
        t.append("=")
        # Coarse coloring by type. bool must be tested before (int, float):
        # bool is a subclass of int, so the broader numeric check would
        # otherwise shadow it and booleans would never be styled cyan.
        if isinstance(v, str):
            t.append(repr(v), style="green")
        elif isinstance(v, bool):
            t.append(repr(v), style="cyan")
        elif isinstance(v, (int, float)):
            t.append(repr(v), style="yellow")
        else:
            t.append(repr(v), style="")
    if len(t.plain) > max_len:
        t.truncate(max_len - 1)
        t.append("…", style="dim")
    return t
|
||||
|
||||
def event_detail_renderable(e: "EventFromEventLog[Event]") -> RenderableType:
    """Render an event's full payload as pretty-printed, highlighted JSON."""
    body = safe_json({
        "idx_in_log": e.idx_in_log,
        "event_type": event_type_name(e),
        "attributes": vars(e.event),
    })
    return Syntax(body, "json", word_wrap=True)
|
||||
|
||||
|
||||
async def non_tui_mode() -> None:
|
||||
# ---------- Non-TUI (stdout) mode, like your current script ----------
|
||||
|
||||
async def run_non_tui(worker_mode: bool) -> None:
|
||||
await init_db()
|
||||
events = await load_all_events()
|
||||
states = compute_states(events)
|
||||
final_state = states[-1]
|
||||
|
||||
if worker_mode:
|
||||
filtered_events = [e for e in events if type(
|
||||
e.event).__name__ in WORKER_EVENT_TYPES]
|
||||
filtered_events = [e for e in events if is_worker_event(e)]
|
||||
events = filtered_events
|
||||
# Recompute states? But states are cumulative, so perhaps just print filtered events and full state, or filter state too.
|
||||
state_dict = json.loads(final_state.model_dump_json())
|
||||
filtered_state = {
|
||||
'node_status': state_dict.get('node_status', {}),
|
||||
'instances': state_dict.get('instances', {}),
|
||||
'runners': state_dict.get('runners', {}),
|
||||
'tasks': state_dict.get('tasks', {}),
|
||||
'last_event_applied_idx': state_dict.get('last_event_applied_idx', 0)
|
||||
}
|
||||
filtered_state = filter_worker_state(final_state)
|
||||
print("Final State (filtered):")
|
||||
print(json.dumps(filtered_state, indent=2))
|
||||
else:
|
||||
@@ -95,464 +167,345 @@ async def non_tui_mode() -> None:
|
||||
print(final_state.model_dump_json(indent=2))
|
||||
|
||||
print("\nEvents:")
|
||||
for event in events:
|
||||
print_event(event)
|
||||
for e in events:
|
||||
etype = event_type_name(e)
|
||||
attrs = ', '.join(f"{k}={value!r}" for k, value in vars(e.event).items())
|
||||
print(f"[{e.idx_in_log}] {etype}: {attrs}")
|
||||
|
||||
|
||||
async def update_events(wrapped_events: "List[EventFromEventLog[Event]]", states: "List[State]",
                        filtered_indices: "Optional[List[int]]" = None) -> bool:
    """Poll for new events, extending *wrapped_events* and *states* in place.

    One new state is appended per new event (incremental apply). When
    *filtered_indices* is given, the indices of new worker events are
    appended to it too. Returns True when anything new arrived.
    """
    prior_count = len(wrapped_events)
    fresh = await get_events_since(prior_count)
    if not fresh:
        return False
    for wrapped in fresh:
        states.append(apply(states[-1], wrapped))
    wrapped_events.extend(fresh)
    if filtered_indices is not None:
        for idx in range(prior_count, len(wrapped_events)):
            if type(wrapped_events[idx].event).__name__ in WORKER_EVENT_TYPES:
                filtered_indices.append(idx)
    return True
|
||||
|
||||
|
||||
def draw_state(win: Any, state: "State", height: int, width: int, worker_mode: bool, state_scroll: int) -> int:
    """Render the state JSON into *win* with simple per-token coloring.

    Scrolled by *state_scroll* (clamped to the content length); returns
    the clamped scroll offset actually used.
    """
    win.clear()
    full_dict = json.loads(state.model_dump_json())
    if worker_mode:
        shown = {
            'node_status': full_dict.get('node_status', {}),
            'instances': full_dict.get('instances', {}),
            'runners': full_dict.get('runners', {}),
            'tasks': full_dict.get('tasks', {}),
            'last_event_applied_idx': full_dict.get('last_event_applied_idx', 0)
        }
    else:
        shown = full_dict
    text_lines = json.dumps(shown, indent=2).split('\n')
    scroll = min(state_scroll, max(0, len(text_lines) - height))
    for row in range(height):
        src = scroll + row
        if src >= len(text_lines):
            break
        line = text_lines[src]
        indent_width = len(line) - len(line.lstrip())
        win.addstr(row, 0, ' ' * indent_width)
        col = indent_width
        body = line.lstrip()
        if not body.startswith('"'):
            win.addstr(row, col, body)
            continue
        key_end = body.find('": ')
        if key_end == -1:
            win.addstr(row, col, body)
            continue
        key_part = body[:key_end + 3]  # include the closing '": '
        win.addstr(row, col, key_part, curses.color_pair(3))
        col += len(key_part)
        value_part = body[key_end + 3:]
        if value_part.startswith('"'):
            pair = 2   # string value
        elif value_part.replace('.', '', 1).isdigit() or (
                value_part.startswith('-') and value_part[1:].replace('.', '', 1).isdigit()):
            pair = 4   # numeric value
        elif value_part in ['true', 'false', 'null']:
            pair = 5   # JSON literal
        else:
            pair = 0   # brackets / anything else: default color
        win.addstr(row, col, value_part, curses.color_pair(pair))
    win.refresh()
    return scroll
|
||||
|
||||
|
||||
def get_event_pairs(event: "EventFromEventLog[Event]") -> List[Tuple[str, int]]:
    """Decompose an event into (text, color-pair) fragments for rendering.

    Color pairs: 5 for the index and bools, 1 for the event type, 3 for
    attribute names, 2 for strings, 4 for numbers, 6 for other values,
    0 for punctuation.
    """
    pairs: List[Tuple[str, int]] = []
    pairs.append((f"[{event.idx_in_log}] ", 5))
    # NOTE: renamed from `event_type_name` to avoid shadowing the
    # module-level helper of the same name.
    type_name = type(event.event).__name__
    pairs.append((type_name.replace('_', ' ').title(), 1))
    pairs.append((": ", 0))
    first = True
    for key, value in vars(event.event).items():
        if not first:
            pairs.append((", ", 0))
        first = False
        pairs.append((key, 3))
        pairs.append(("=", 0))
        # bool must be tested before (int, float): bool is a subclass of
        # int, so the numeric check would otherwise shadow it and booleans
        # would never receive their own color.
        if isinstance(value, str):
            color = 2
        elif isinstance(value, bool):
            color = 5
        elif isinstance(value, (int, float)):
            color = 4
        else:
            color = 6
        pairs.append((repr(value), color))
    return pairs
|
||||
|
||||
|
||||
def calculate_event_lines(pairs: List[Tuple[str, int]], win_width: int, subsequent_indent: int) -> int:
    """Count how many display rows the fragments occupy when wrapped.

    Continuation rows start at column *subsequent_indent*.
    """
    rows, col = 1, 0
    for fragment, _ in pairs:
        consumed = 0
        total = len(fragment)
        while consumed < total:
            take = min(total - consumed, win_width - col)
            consumed += take
            col += take
            if consumed < total:
                # Fragment spills over: wrap onto an indented new row.
                rows += 1
                col = subsequent_indent
    return rows
|
||||
|
||||
|
||||
def render_event(win: Any, start_y: int, pairs: List[Tuple[str, int]], is_bold: bool, win_width: int,
                 subsequent_indent: int) -> int:
    """Draw wrapped (text, color) fragments into *win* starting at *start_y*.

    Continuation rows begin at *subsequent_indent*; drawing stops early when
    the window runs out of rows. Returns the row index just past the output.
    """
    row, col = start_y, 0
    bold = curses.A_BOLD if is_bold else 0
    for fragment, pair in pairs:
        style = curses.color_pair(pair) | bold
        pos = 0
        while pos < len(fragment):
            take = min(len(fragment) - pos, win_width - col)
            piece = fragment[pos:pos + take]
            try:
                win.addstr(row, col, piece, style)
            except curses.error:
                # Writing into the bottom-right cell raises; ignore it.
                pass
            pos += take
            col += take
            if pos < len(fragment):
                row += 1
                if row >= win.getmaxyx()[0]:
                    return row
                col = subsequent_indent
    if col > 0:
        row += 1
    return row
|
||||
|
||||
|
||||
def draw_events(win: Any, events_list: List[EventFromEventLog[Event]], current_events: int, height: int) -> None:
    """Draw the event list centered on the selected event.

    The selected event is drawn bold; as many whole neighbours as fit are
    drawn above and below it, splitting the spare rows roughly evenly.
    """
    win.clear()
    if not events_list:
        win.addstr(0, 0, "No events")
        win.refresh()
        return
    win_width = win.getmaxyx()[1]

    def measure(idx: int) -> Tuple[List[Tuple[str, int]], int, int]:
        # (fragments, hanging indent, wrapped row count) for one event.
        ev = events_list[idx]
        frags = get_event_pairs(ev)
        hang = len(f"[{ev.idx_in_log}] ")
        return frags, hang, calculate_event_lines(frags, win_width, hang)

    current_pairs, subsequent_indent, lines_current = measure(current_events)
    if lines_current > height:
        # The selection alone overflows the pane: show just it, clipped.
        render_event(win, 0, current_pairs, True, win_width, subsequent_indent)
        win.refresh()
        return

    target_above = (height - lines_current) // 2
    target_below = height - lines_current - target_above

    # Walk upward collecting whole previous events that still fit.
    prev_events: List[int] = []
    budget = target_above
    i = current_events - 1
    while i >= 0 and budget > 0:
        _, _, rows = measure(i)
        if rows > budget:
            break
        budget -= rows
        prev_events.append(i)
        i -= 1
    prev_events.reverse()

    # Walk downward collecting whole next events that still fit.
    next_events: List[int] = []
    budget = target_below
    j = current_events + 1
    while j < len(events_list) and budget > 0:
        _, _, rows = measure(j)
        if rows > budget:
            break
        budget -= rows
        next_events.append(j)
        j += 1

    # Vertically center the whole group when it underfills the pane.
    total_lines = lines_current + sum(measure(k)[2] for k in prev_events + next_events)
    y = (height - total_lines) // 2 if total_lines < height else 0

    for k in prev_events:
        frags, hang, _ = measure(k)
        y = render_event(win, y, frags, False, win_width, hang)
    y = render_event(win, y, current_pairs, True, win_width, subsequent_indent)
    for k in next_events:
        frags, hang, _ = measure(k)
        y = render_event(win, y, frags, False, win_width, hang)

    win.refresh()
|
||||
|
||||
|
||||
def draw_status(win: Any, realtime: bool, current: int, total_events: int) -> None:
    """Redraw the one-line status bar with mode, position and key help."""
    win.clear()
    mode = "Realtime" if realtime else "Timetravel"
    help_text = (f"Mode: {mode} | Current event: {current} / {total_events}"
                 " | Arrows: navigate events, [/]: scroll state, g: goto, r: toggle realtime, q: quit")
    win.addstr(0, 0, help_text)
    win.refresh()
|
||||
|
||||
|
||||
def get_input(stdscr: Any, prompt: str) -> str:
    """Prompt for a short (<= 20 char) line of text at the top of *stdscr*.

    Echo is enabled for the duration of the read and disabled afterwards.
    """
    curses.echo()
    stdscr.addstr(0, 0, prompt)
    stdscr.refresh()
    raw = stdscr.getstr(0, len(prompt), 20)
    text = raw.decode('utf-8')
    curses.noecho()
    return text
|
||||
|
||||
|
||||
def get_key(win: Any) -> Any:
    """Read one key, decoding common xterm escape sequences by hand.

    Returns -1 on timeout, a curses key constant or raw code otherwise,
    and the strings 'CTRL_UP' / 'CTRL_DOWN' for ctrl+arrow sequences.
    Unrecognized escape sequences fall through to returning 27 (ESC).
    """
    first = win.getch()
    if first == -1:
        return -1
    if first != 27:
        return first
    # Got ESC: look for a CSI sequence (ESC [ ...).
    second = win.getch()
    if second == -1:
        return 27              # bare escape key
    if second == 91:           # '['
        third = win.getch()
        if third == -1:
            return -1
        if third == 65:        # 'A'
            return curses.KEY_UP
        if third == 66:        # 'B'
            return curses.KEY_DOWN
        if third == 53:        # '5' -> page-up when followed by '~'
            if win.getch() == 126:
                return curses.KEY_PPAGE
        if third == 54:        # '6' -> page-down when followed by '~'
            if win.getch() == 126:
                return curses.KEY_NPAGE
        if third == 49:        # '1' -> maybe ESC [ 1 ; 5 A/B (ctrl+arrow)
            fourth = win.getch()
            if fourth == -1:
                return -1
            if fourth == 59:           # ';'
                fifth = win.getch()
                if fifth == -1:
                    return -1
                if fifth == 53:        # '5' = ctrl modifier
                    sixth = win.getch()
                    if sixth == -1:
                        return -1
                    if sixth == 65:
                        return 'CTRL_UP'
                    if sixth == 66:
                        return 'CTRL_DOWN'
    return first
|
||||
|
||||
|
||||
def tui(stdscr: Any) -> None:
|
||||
curses.start_color()
|
||||
curses.init_pair(1, curses.COLOR_BLUE, curses.COLOR_BLACK)
|
||||
curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK)
|
||||
curses.init_pair(3, curses.COLOR_MAGENTA, curses.COLOR_BLACK)
|
||||
curses.init_pair(4, curses.COLOR_YELLOW, curses.COLOR_BLACK)
|
||||
curses.init_pair(5, curses.COLOR_CYAN, curses.COLOR_BLACK)
|
||||
curses.init_pair(6, curses.COLOR_WHITE, curses.COLOR_BLACK)
|
||||
curses.use_default_colors()
|
||||
stdscr.timeout(100)
|
||||
curses.curs_set(0)
|
||||
|
||||
wrapped_events: List[EventFromEventLog[Event]] = []
|
||||
states: List[State] = [State()]
|
||||
asyncio.run(init_db())
|
||||
asyncio.run(update_events(wrapped_events, states)) # Initial load
|
||||
|
||||
filtered_indices: Optional[List[int]] = None
|
||||
current_filtered: int = -1
|
||||
current: int = -1
|
||||
if worker_mode:
|
||||
filtered_indices = [i for i in range(len(wrapped_events)) if
|
||||
type(wrapped_events[i].event).__name__ in WORKER_EVENT_TYPES]
|
||||
current_filtered = len(filtered_indices) - \
|
||||
1 if filtered_indices else -1
|
||||
else:
|
||||
current = len(wrapped_events) - 1 if wrapped_events else -1
|
||||
|
||||
realtime: bool = False
|
||||
last_update: float = time.time()
|
||||
update_interval: float = 1.0
|
||||
state_scroll: int = 0
|
||||
|
||||
while True:
|
||||
height, width = stdscr.getmaxyx()
|
||||
status_height = 1
|
||||
pane_height = height - status_height
|
||||
pane_width = width // 2
|
||||
|
||||
state_win = curses.newwin(pane_height, pane_width, 0, 0)
|
||||
events_win = curses.newwin(
|
||||
pane_height, width - pane_width, 0, pane_width)
|
||||
status_win = curses.newwin(status_height, width, pane_height, 0)
|
||||
# ---------- Textual TUI ----------
|
||||
|
||||
class StateView(Static):
|
||||
"""Left pane: shows state JSON, with optional worker filter."""
|
||||
def update_state(self, state: State, worker_mode: bool, index_in_log_for_status: Optional[int]) -> None:
|
||||
if worker_mode:
|
||||
assert filtered_indices is not None
|
||||
current_original = filtered_indices[current_filtered] if current_filtered >= 0 else -1
|
||||
events_list = [wrapped_events[i] for i in filtered_indices]
|
||||
current_events = current_filtered
|
||||
data = filter_worker_state(state)
|
||||
json_str = json.dumps(data, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
current_original = current
|
||||
events_list = wrapped_events
|
||||
current_events = current
|
||||
json_str = state.model_dump_json(indent=2)
|
||||
syntax = Syntax(json_str, "json", word_wrap=True)
|
||||
title = f"State after event #{index_in_log_for_status}" if index_in_log_for_status is not None else "Initial State"
|
||||
self.update(Panel(syntax, title=title, border_style="cyan"))
|
||||
|
||||
state_idx = current_original + 1 if current_original >= 0 else 0
|
||||
state_scroll = draw_state(
|
||||
state_win, states[state_idx], pane_height, pane_width, worker_mode, state_scroll)
|
||||
draw_events(events_win, events_list, current_events, pane_height)
|
||||
total_events = len(wrapped_events) - 1 if wrapped_events else -1
|
||||
draw_status(status_win, realtime,
|
||||
current_original if worker_mode else current, total_events)
|
||||
class EventListItem(ListItem):
|
||||
def __init__(self, e: EventFromEventLog[Event]) -> None:
|
||||
super().__init__(Static(summarize_event_line(e)))
|
||||
self._event = e
|
||||
|
||||
key = get_key(stdscr)
|
||||
if key != -1:
|
||||
if key == curses.KEY_UP:
|
||||
if worker_mode and current_filtered > 0:
|
||||
current_filtered -= 1
|
||||
elif not worker_mode and current > 0:
|
||||
current -= 1
|
||||
elif key == 'CTRL_UP':
|
||||
if worker_mode:
|
||||
current_filtered = max(0, current_filtered - 5)
|
||||
else:
|
||||
current = max(0, current - 5)
|
||||
elif key == curses.KEY_DOWN:
|
||||
assert filtered_indices is not None
|
||||
if worker_mode and current_filtered < len(filtered_indices) - 1:
|
||||
current_filtered += 1
|
||||
elif not worker_mode and current < len(wrapped_events) - 1:
|
||||
current += 1
|
||||
elif key == 'CTRL_DOWN':
|
||||
assert filtered_indices is not None
|
||||
if worker_mode:
|
||||
current_filtered = min(
|
||||
len(filtered_indices) - 1, current_filtered + 5)
|
||||
else:
|
||||
current = min(len(wrapped_events) - 1, current + 5)
|
||||
elif key == ord('['):
|
||||
state_scroll = max(0, state_scroll - pane_height // 2)
|
||||
elif key == ord(']'):
|
||||
state_scroll += pane_height // 2 # clamped in draw_state
|
||||
elif key == ord('q'):
|
||||
@property
|
||||
def wrapped_event(self) -> EventFromEventLog[Event]:
|
||||
return self._event
|
||||
|
||||
class EventDetail(Static):
    """Right-bottom: details of the selected event."""

    def show_event(self, e: "Optional[EventFromEventLog[Event]]") -> None:
        """Replace the panel contents with *e*'s details (or a placeholder)."""
        if e is None:
            self.update(Panel(Text("No event selected.", style="dim"), title="Event Details"))
            return
        header = f"Event #{e.idx_in_log} • {event_type_name(e)}"
        self.update(Panel(event_detail_renderable(e), title=header, border_style="magenta"))
|
||||
|
||||
class StatusBar(Static):
    """One-line status strip: mode, event count, selection, and key help."""

    def set_status(self, realtime: bool, total_events: int, current_idx_in_log: Optional[int]) -> None:
        mode = "Realtime" if realtime else "Timetravel"
        pieces = [f"[{mode}]", f"Events: {total_events}"]
        if current_idx_in_log is not None:
            pieces.append(f"Current: #{current_idx_in_log}")
        pieces.append("Keys: ↑/↓ Select • PgUp/PgDn Scroll • Ctrl+↑/↓ ±5 • [/] State PgUp/PgDn • g Goto • r Realtime • q Quit")
        self.update(Text(" ".join(pieces), style="dim"))
|
||||
|
||||
|
||||
class GotoPrompt(Static):
    """Simple inline goto prompt (appears above Footer)."""

    class Submitted(Message):
        """Posted on submit; value is the parsed int, or None on bad input."""

        def __init__(self, value: Optional[int]) -> None:
            super().__init__()
            self.value = value

    def compose(self) -> ComposeResult:
        yield Label("Go to event id (idx_in_log):", id="goto-label")
        yield Input(placeholder="e.g., 123", id="goto-input")

    def on_mount(self) -> None:
        # Focus the text box immediately so typing works without a click.
        self.query_one(Input).focus()

    @on(Input.Submitted)
    def _submitted(self, event: Input.Submitted) -> None:
        raw = (event.value or "").strip()
        try:
            parsed: Optional[int] = int(raw)
        except ValueError:
            parsed = None
        self.post_message(self.Submitted(parsed))
|
||||
|
||||
|
||||
class EventLogApp(App):
    """Textual TUI: state pane on the left, event list + detail on the right."""

    CSS = """
    Screen {
        layout: vertical;
    }
    #main {
        height: 1fr;
    }
    #left {
        width: 60%;
    }
    #right {
        width: 40%;
    }
    #events {
        height: 3fr;
    }
    #detail {
        height: 2fr;
        border: tall;
    }
    #status {
        height: 1;
        padding: 0 1;
    }
    #goto {
        dock: bottom;
        height: 3;
        padding: 1 2;
        background: $panel;
        border: round $accent;
    }
    """

    BINDINGS = [
        Binding("q", "quit", "Quit"),
        Binding("r", "toggle_realtime", "Realtime"),
        Binding("[", "state_page_up", "State PgUp"),
        Binding("]", "state_page_down", "State PgDn"),
        Binding("g", "prompt_goto", "Goto"),
        Binding("ctrl+up", "jump_up", "Jump Up"),
        Binding("ctrl+down", "jump_down", "Jump Down"),
    ]

    # Reactive state
    realtime: reactive[bool] = reactive(False)
    worker_mode: bool

    # Data
    wrapped_events: List[EventFromEventLog[Event]]
    states: List[State]
    filtered_indices: Optional[List[int]]  # maps filtered idx -> original idx
    update_interval: float = 1.0
    _poll_timer = None

    def __init__(self, worker_mode: bool) -> None:
        super().__init__()
        self.worker_mode = worker_mode
        self.wrapped_events = []
        self.states = [State()]
        self.filtered_indices = None

    async def on_mount(self) -> None:
        await init_db()
        await self._initial_load()
        # Keep polling the log for fresh events.
        self._poll_timer = self.set_interval(self.update_interval, self._tick_poll)
        # Start with the newest event selected.
        self._select_last()

    async def _initial_load(self) -> None:
        """Load the full log, rebuild states, and (re)populate the list."""
        self.wrapped_events = await load_all_events()
        self.states = compute_states(self.wrapped_events)

        # Build the filtered index map only when the worker view is active.
        if self.worker_mode:
            self.filtered_indices = [i for i, e in enumerate(self.wrapped_events) if is_worker_event(e)]
        else:
            self.filtered_indices = None

        event_list = self.query_one("#events", ListView)
        event_list.clear()
        for wrapped in self._view_events():
            event_list.append(EventListItem(wrapped))

        self._refresh_views()

    def compose(self) -> ComposeResult:
        # Main split, then the status bar and footer stacked below it.
        with Horizontal(id="main"):
            with Vertical(id="left"):
                yield StateView(id="state")
            with Vertical(id="right"):
                yield ListView(id="events")
                yield EventDetail(id="detail")
        yield StatusBar(id="status")
        yield Footer()

    def _current_original_index(self) -> int:
        """Translate the ListView cursor into an index into wrapped_events."""
        cursor = self.query_one("#events", ListView).index
        if cursor is None or cursor < 0:
            return -1
        if self.filtered_indices is None:
            return cursor
        if cursor >= len(self.filtered_indices):
            return -1
        return self.filtered_indices[cursor]

    def _view_events(self) -> List[EventFromEventLog[Event]]:
        """Events currently shown in the list (filtered in worker mode)."""
        if self.filtered_indices is None:
            return self.wrapped_events
        return [self.wrapped_events[i] for i in self.filtered_indices]

    def _select_last(self) -> None:
        event_list = self.query_one("#events", ListView)
        count = len(event_list.children)
        if count:
            event_list.index = count - 1

    def _refresh_views(self) -> None:
        """Sync the state pane, detail pane and status bar to the selection."""
        original_idx = self._current_original_index()
        state_idx = (original_idx + 1) if original_idx >= 0 else 0
        selected = self.wrapped_events[original_idx] if original_idx >= 0 else None
        idx_in_log = selected.idx_in_log if selected is not None else None

        self.query_one("#state", StateView).update_state(
            self.states[state_idx], self.worker_mode, idx_in_log)
        self.query_one("#detail", EventDetail).show_event(selected)
        self.query_one("#status", StatusBar).set_status(
            self.realtime, len(self.wrapped_events), idx_in_log)

    async def _poll_once(self) -> bool:
        """Fetch and append new events; return True if updated."""
        fresh = await get_events_since(len(self.wrapped_events))
        if not fresh:
            return False

        # Extend states incrementally (avoid recomputing all).
        for wrapped in fresh:
            self.states.append(apply(self.states[-1], wrapped))

        start_len = len(self.wrapped_events)
        self.wrapped_events.extend(fresh)

        # Mirror the new events into the filter map and the ListView.
        event_list = self.query_one("#events", ListView)
        if self.worker_mode:
            if self.filtered_indices is None:
                self.filtered_indices = []
            for k in range(start_len, len(self.wrapped_events)):
                if is_worker_event(self.wrapped_events[k]):
                    self.filtered_indices.append(k)
                    event_list.append(EventListItem(self.wrapped_events[k]))
        else:
            for k in range(start_len, len(self.wrapped_events)):
                event_list.append(EventListItem(self.wrapped_events[k]))

        # Auto-follow the tail in realtime mode.
        if self.realtime:
            self._select_last()

        self._refresh_views()
        return True

    def _tick_poll(self) -> None:
        # The timer callback is sync; hand the actual work to the event loop.
        asyncio.create_task(self._poll_once())

    # ------ Actions / key handlers ------
    def action_quit(self) -> None:
        self.exit()

    def action_toggle_realtime(self) -> None:
        self.realtime = not self.realtime
        if self.realtime:
            self._select_last()
        self._refresh_views()

    def action_state_page_up(self) -> None:
        self.query_one("#state", StateView).scroll_page_up()

    def action_state_page_down(self) -> None:
        self.query_one("#state", StateView).scroll_page_down()

    def action_jump_up(self) -> None:
        event_list = self.query_one("#events", ListView)
        if event_list.children:
            event_list.index = max(0, (event_list.index or 0) - 5)
            self._refresh_views()

    def action_jump_down(self) -> None:
        event_list = self.query_one("#events", ListView)
        if event_list.children:
            event_list.index = min(len(event_list.children) - 1, (event_list.index or 0) + 5)
            self._refresh_views()

    def action_prompt_goto(self) -> None:
        # Only one prompt at a time.
        if self.query("#goto"):
            return
        self.mount(GotoPrompt(id="goto"))

    @on(GotoPrompt.Submitted)
    def _on_goto_submitted(self, msg: GotoPrompt.Submitted) -> None:
        # Remove the prompt regardless of outcome.
        for node in self.query("#goto"):
            node.remove()

        if msg.value is None:
            return

        # Select the first shown event whose idx_in_log matches.
        event_list = self.query_one("#events", ListView)
        for position, wrapped in enumerate(self._view_events()):
            if wrapped.idx_in_log == msg.value:
                event_list.index = position
                self._refresh_views()
                break
|
||||
elif key == ord('r'):
|
||||
realtime = not realtime
|
||||
if realtime:
|
||||
assert filtered_indices is not None
|
||||
if worker_mode:
|
||||
current_filtered = len(
|
||||
filtered_indices) - 1 if filtered_indices else -1
|
||||
else:
|
||||
current = len(wrapped_events) - \
|
||||
1 if wrapped_events else -1
|
||||
state_scroll = 0
|
||||
elif key == ord('g'):
|
||||
stdscr.timeout(-1) # block for input
|
||||
input_str = get_input(status_win, "Go to event: ")
|
||||
try:
|
||||
goto = int(input_str)
|
||||
if worker_mode:
|
||||
assert filtered_indices is not None
|
||||
for i, orig in enumerate(filtered_indices):
|
||||
if wrapped_events[orig].idx_in_log == goto:
|
||||
current_filtered = i
|
||||
state_scroll = 0
|
||||
break
|
||||
else:
|
||||
for i in range(len(wrapped_events)):
|
||||
if wrapped_events[i].idx_in_log == goto:
|
||||
current = i
|
||||
state_scroll = 0
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
stdscr.timeout(100)
|
||||
status_win.clear()
|
||||
status_win.refresh()
|
||||
|
||||
if realtime and time.time() - last_update > update_interval:
|
||||
updated = asyncio.run(update_events(
|
||||
wrapped_events, states, filtered_indices if worker_mode else None))
|
||||
if updated:
|
||||
assert filtered_indices is not None
|
||||
if worker_mode:
|
||||
current_filtered = len(filtered_indices) - 1
|
||||
else:
|
||||
current = len(wrapped_events) - 1
|
||||
state_scroll = 0
|
||||
last_update = time.time()
|
||||
@on(ListView.Highlighted, "#events")
|
||||
@on(ListView.Selected, "#events")
|
||||
def _on_event_selected(self, *_: Any) -> None:
|
||||
# Update panes when selection changes
|
||||
self._refresh_views()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Read and display events from the event log')
|
||||
# ---------- Entrypoint ----------
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description='Read and display events from the event log (Textual UI)')
|
||||
parser.add_argument('--worker', action='store_true',
|
||||
help='Only show worker-related events (task, streaming, instance, runner status)')
|
||||
parser.add_argument('--no-ui', action='store_true',
|
||||
help='Print to stdout (non-interactive), like the original non-TUI mode')
|
||||
args = parser.parse_args()
|
||||
|
||||
worker_mode = args.worker
|
||||
# Non-interactive fallback if no TTY or user requests it
|
||||
if args.no_ui or not sys.stdout.isatty():
|
||||
asyncio.run(run_non_tui(worker_mode=args.worker))
|
||||
return
|
||||
|
||||
if not sys.stdout.isatty():
|
||||
asyncio.run(non_tui_mode())
|
||||
else:
|
||||
try:
|
||||
curses.wrapper(tui)
|
||||
except curses.error as e:
|
||||
if "could not find terminal" in str(e):
|
||||
print("Error: Could not find terminal. Falling back to non-TUI mode.")
|
||||
asyncio.run(non_tui_mode())
|
||||
else:
|
||||
raise
|
||||
# TUI mode
|
||||
app = EventLogApp(worker_mode=args.worker)
|
||||
app.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -60,19 +60,16 @@ def mlx_setup(
|
||||
target_cache = int(1.10 * (model_bytes + kv_bytes)) # +10% slack
|
||||
target_cache = min(target_cache, int(cache_frac_of_mrwss * mrwss))
|
||||
target_cache = min(target_cache, memsize)
|
||||
|
||||
runner_print(f"{target_cache=}")
|
||||
|
||||
mx.set_cache_limit(max(target_cache, 0))
|
||||
return
|
||||
|
||||
# Optional hard cap (keeps total MLX usage under control)
|
||||
with contextlib.suppress(Exception):
|
||||
mx.set_memory_limit(int(0.85 * mrwss))
|
||||
|
||||
# Wiring: off by default; if you re‑enable, wire at most a small fraction.
|
||||
if wired_frac_of_mrwss > 0.0:
|
||||
target_wired = min(int(wired_frac_of_mrwss * mrwss), int(0.5 * model_bytes))
|
||||
target_wired = int(wired_frac_of_mrwss * mrwss)
|
||||
target_wired = min(target_wired, target_cache) # don’t wire more than cache
|
||||
|
||||
runner_print(f"{target_wired=}")
|
||||
with contextlib.suppress(Exception): # older macOS won’t have this
|
||||
mx.set_wired_limit(max(target_wired, 0))
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ def logger_setup(log_file: Path, verbosity: int = 0):
|
||||
if verbosity == 0:
|
||||
_ = logger.add( # type: ignore
|
||||
sys.__stderr__, # type: ignore
|
||||
format="[ {time:hh:mmA} | <level>{level: <8}</level>] <level>{message}</level>",
|
||||
format="[ {time:hh:mm:ss.SSSSA} | <level>{level: <8}</level>] <level>{message}</level>",
|
||||
level="INFO",
|
||||
colorize=True,
|
||||
enqueue=True,
|
||||
@@ -29,7 +29,7 @@ def logger_setup(log_file: Path, verbosity: int = 0):
|
||||
elif verbosity == 1:
|
||||
_ = logger.add( # type: ignore
|
||||
sys.__stderr__, # type: ignore
|
||||
format="[ {time:hh:mmA} | <level>{level: <8}</level>] <level>{message}</level>",
|
||||
format="[ {time:hh:mm:ss.SSSSA} | <level>{level: <8}</level>] <level>{message}</level>",
|
||||
level="INFO",
|
||||
colorize=True,
|
||||
enqueue=True,
|
||||
|
||||
@@ -58,16 +58,17 @@ def runner_write_response(obj: RunnerResponse) -> None:
|
||||
|
||||
async def supervisor_read_response(
|
||||
proc: asyncio.subprocess.Process,
|
||||
) -> RunnerResponse | None:
|
||||
) -> RunnerResponse:
|
||||
assert proc.stdout is not None, (
|
||||
"proc.stdout should not be None when created with stdout=PIPE"
|
||||
)
|
||||
line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=180)
|
||||
# TODO: We could put a timeout on this if we decide to send heartbeats from the runner.
|
||||
# This lets us handle cases where the process dies at some point not during an inference.
|
||||
line_bytes: bytes = await proc.stdout.readline()
|
||||
if not line_bytes:
|
||||
raise EOFError('No more data to read when reading response from runner.')
|
||||
line: str = line_bytes.decode("utf-8").strip()
|
||||
|
||||
if not line:
|
||||
return None
|
||||
|
||||
try:
|
||||
return RunnerResponseTypeAdapter.validate_json(line)
|
||||
except Exception as err:
|
||||
|
||||
@@ -112,7 +112,7 @@ async def main():
|
||||
model_shard_meta = setup_message.model_shard_meta
|
||||
hosts = setup_message.hosts
|
||||
|
||||
mlx_setup(int(get_weights_size_kb(model_shard_meta) // 2**10))
|
||||
mlx_setup(int(get_weights_size_kb(model_shard_meta) // 2**10), cache_frac_of_mrwss=0.8, wired_frac_of_mrwss=0.8)
|
||||
|
||||
# For testing - these are fake break conditions
|
||||
if model_shard_meta.immediate_exception:
|
||||
|
||||
@@ -114,6 +114,13 @@ class RunnerSupervisor:
|
||||
"""
|
||||
Read from the queue with a timeout, but also check if the read_task has failed.
|
||||
"""
|
||||
try:
|
||||
assert not self.read_task.done()
|
||||
except AssertionError as e_assert:
|
||||
e = self.read_task.exception()
|
||||
assert e is not None
|
||||
raise e from e_assert
|
||||
|
||||
queue_task = asyncio.create_task(self.read_queue.get())
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
@@ -137,13 +144,14 @@ class RunnerSupervisor:
|
||||
return response
|
||||
|
||||
if self.read_task in done:
|
||||
await self.read_task # Re-raises any exception from read_task
|
||||
logger.error(
|
||||
"Unreachable code run. We should have raised an error on the read_task being done."
|
||||
)
|
||||
|
||||
try:
|
||||
await self.read_task # Re-raises any exception from read_task
|
||||
except Exception:
|
||||
raise # bubble up exception
|
||||
raise RunnerError("RunnerStopped", "Runner read loop terminated unexpectedly before any response.", "")
|
||||
|
||||
# if we haven't read from the queue, we have timed out.
|
||||
await self.astop()
|
||||
await self.astop() # TODO: This could be handled by the called or _read_with_error_check - as we don't want a false Timeout to bring the whole runner down.
|
||||
raise asyncio.TimeoutError()
|
||||
|
||||
async def stream_response(
|
||||
@@ -186,7 +194,7 @@ class RunnerSupervisor:
|
||||
try:
|
||||
response = await self._read_with_error_check(timeout)
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.bind(user_facing=True).info(
|
||||
logger.bind(user_facing=True).error(
|
||||
f"Generation timed out during {'prefil' if timeout == prefil_timeout else 'decoding stage'}"
|
||||
)
|
||||
raise e
|
||||
@@ -219,16 +227,17 @@ class RunnerSupervisor:
|
||||
|
||||
async def _read_coro(self):
|
||||
while True:
|
||||
response: RunnerResponse | None = await supervisor_read_response(
|
||||
self.runner_process
|
||||
)
|
||||
if response is None:
|
||||
# Runner process died unexpectedly (C++ crash)
|
||||
try:
|
||||
response: RunnerResponse = await supervisor_read_response(
|
||||
self.runner_process
|
||||
)
|
||||
except EOFError:
|
||||
e = await self._raise_crashed()
|
||||
if e:
|
||||
raise e from EOFError
|
||||
# Runner process died unexpectedly (C++ crash)
|
||||
raise e from EOFError # TODO: Do we just want to create an error and put it on the read_queue here?
|
||||
else:
|
||||
break
|
||||
continue
|
||||
|
||||
match response:
|
||||
case PrintResponse():
|
||||
|
||||
@@ -68,6 +68,7 @@ def get_init_timeout(model_shard_meta: ShardMetadata) -> float:
|
||||
|
||||
|
||||
def get_prefil_timeout(model_shard_meta: ShardMetadata) -> float:
|
||||
return 30.0 # TODO: Proper prefil timeout calculation, but this requires knowing the number of tokens in the prompt.
|
||||
weights_size_gb = get_weights_size_kb(model_shard_meta) / (1024 * 1024)
|
||||
|
||||
tokens = 1000 # constant for now - the prompt is only tokenized in the device...
|
||||
|
||||
@@ -55,6 +55,7 @@ async def read_streaming_response(
|
||||
event.chunk, TokenChunk
|
||||
):
|
||||
response_string += event.chunk.text
|
||||
token_count += 1
|
||||
if event.chunk.finish_reason:
|
||||
finish_reason = event.chunk.finish_reason
|
||||
|
||||
|
||||
@@ -183,15 +183,15 @@ async def test_ttft(
|
||||
if not first_chunk_seen_1:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
_, seen_task_finished_1, response_string_1, _ = await read_streaming_response(
|
||||
_, seen_task_finished_1, response_string_1, token_count_1 = await read_streaming_response(
|
||||
global_events
|
||||
)
|
||||
# # total_time_1 = time.time() - task_created_time_1
|
||||
total_time_1 = time.time() - task_created_time_1
|
||||
|
||||
assert seen_task_finished_1
|
||||
|
||||
# Wait for first task to complete
|
||||
await asyncio.sleep(3.0)
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
# Second inference
|
||||
task2_params = ChatCompletionTaskParams(
|
||||
@@ -238,10 +238,10 @@ async def test_ttft(
|
||||
if not first_chunk_seen_2:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
_, seen_task_finished_2, response_string_2, _ = await read_streaming_response(
|
||||
_, seen_task_finished_2, response_string_2, token_count_2 = await read_streaming_response(
|
||||
global_events, filter_task=TASK_2_ID
|
||||
)
|
||||
# # total_time_2 = time.time() - task_created_time_2
|
||||
total_time_2 = time.time() - task_created_time_2
|
||||
|
||||
assert seen_task_finished_2
|
||||
assert time_to_first_token_1
|
||||
@@ -249,41 +249,41 @@ async def test_ttft(
|
||||
|
||||
# Calculate TPS metrics
|
||||
# Prompt is approximately 45 tokens according to user
|
||||
# prompt_tokens = 45
|
||||
prompt_tokens = 45
|
||||
|
||||
# # Prefill TPS = prompt tokens / time to first token
|
||||
# prefill_tps_1 = prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0
|
||||
# prefill_tps_2 = prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0
|
||||
# Prefill TPS = prompt tokens / time to first token
|
||||
prefill_tps_1 = prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0
|
||||
prefill_tps_2 = prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0
|
||||
|
||||
# # Generation TPS = generated tokens / generation time
|
||||
# # Generation time = total time - time to first token
|
||||
# generation_time_1 = total_time_1 - time_to_first_token_1
|
||||
# generation_time_2 = total_time_2 - time_to_first_token_2
|
||||
# generation_tps_1 = token_count_1 / generation_time_1 if generation_time_1 > 0 else 0
|
||||
# generation_tps_2 = token_count_2 / generation_time_2 if generation_time_2 > 0 else 0
|
||||
# Generation TPS = generated tokens / generation time
|
||||
# Generation time = total time - time to first token
|
||||
generation_time_1 = total_time_1 - time_to_first_token_1
|
||||
generation_time_2 = total_time_2 - time_to_first_token_2
|
||||
generation_tps_1 = token_count_1 / generation_time_1 if generation_time_1 > 0 else 0
|
||||
generation_tps_2 = token_count_2 / generation_time_2 if generation_time_2 > 0 else 0
|
||||
|
||||
# # Display time to first token profiling results
|
||||
# print("\n=== Time to First Token Profiling ===")
|
||||
# print(f"First inference ('{task1.task_params.messages[0].content}'):")
|
||||
# print(f" Time to first token: {time_to_first_token_1:.3f}s")
|
||||
# print(f" Total completion time: {total_time_1:.3f}s")
|
||||
# print(f" Tokens generated: {token_count_1}")
|
||||
# print(f" Response length: {len(response_string_1)} chars")
|
||||
# print(f" Prefill TPS: {prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_1:.3f}s)")
|
||||
# print(f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)")
|
||||
# Display time to first token profiling results
|
||||
print("\n=== Time to First Token Profiling ===")
|
||||
print(f"First inference ('{task1.task_params.messages[0].content}'):")
|
||||
print(f" Time to first token: {time_to_first_token_1:.3f}s")
|
||||
print(f" Total completion time: {total_time_1:.3f}s")
|
||||
print(f" Tokens generated: {token_count_1}")
|
||||
print(f" Response length: {len(response_string_1)} chars")
|
||||
print(f" Prefill TPS: {prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_1:.3f}s)")
|
||||
print(f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)")
|
||||
|
||||
# print(f"\nSecond inference ('{task2.task_params.messages[0].content}'):")
|
||||
# print(f" Time to first token: {time_to_first_token_2:.3f}s")
|
||||
# print(f" Total completion time: {total_time_2:.3f}s")
|
||||
# print(f" Tokens generated: {token_count_2}")
|
||||
# print(f" Response length: {len(response_string_2)} chars")
|
||||
# print(f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)")
|
||||
# print(f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)")
|
||||
print(f"\nSecond inference ('{task2.task_params.messages[0].content}'):")
|
||||
print(f" Time to first token: {time_to_first_token_2:.3f}s")
|
||||
print(f" Total completion time: {total_time_2:.3f}s")
|
||||
print(f" Tokens generated: {token_count_2}")
|
||||
print(f" Response length: {len(response_string_2)} chars")
|
||||
print(f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)")
|
||||
print(f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)")
|
||||
|
||||
# print("\nComparison:")
|
||||
# print(f" Second inference time to first token: {time_to_first_token_2/time_to_first_token_1:.2f}x the first")
|
||||
# print(f" Second inference prefill TPS: {prefill_tps_2/prefill_tps_1:.2f}x the first")
|
||||
# print(f" Second inference generation TPS: {generation_tps_2/generation_tps_1:.2f}x the first")
|
||||
print("\nComparison:")
|
||||
print(f" Second inference time to first token: {time_to_first_token_2/time_to_first_token_1:.2f}x the first")
|
||||
print(f" Second inference prefill TPS: {prefill_tps_2/prefill_tps_1:.2f}x the first")
|
||||
print(f" Second inference generation TPS: {generation_tps_2/generation_tps_1:.2f}x the first")
|
||||
|
||||
# Basic assertions to ensure responses make sense
|
||||
assert len(response_string_1) > 0
|
||||
|
||||
68
uv.lock
generated
68
uv.lock
generated
@@ -269,6 +269,7 @@ dependencies = [
|
||||
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "sqlalchemy", extra = ["asyncio"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "sqlmodel", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "textual", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "typeguard", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -313,6 +314,7 @@ requires-dist = [
|
||||
{ name = "rustworkx", specifier = ">=0.17.1" },
|
||||
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.43" },
|
||||
{ name = "sqlmodel", specifier = ">=0.0.24" },
|
||||
{ name = "textual", specifier = ">=5.3.0" },
|
||||
{ name = "transformers", specifier = ">=4.55.2" },
|
||||
{ name = "typeguard", specifier = ">=4.4.4" },
|
||||
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
|
||||
@@ -567,6 +569,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/4a/4175a563579e884192ba6e81725fc0448b042024419be8d83aa8a80a3f44/jiter-0.10.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa96f2abba33dc77f79b4cf791840230375f9534e5fac927ccceb58c5e604a5", size = 354213, upload-time = "2025-05-18T19:04:41.894Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linkify-it-py"
|
||||
version = "2.0.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "uc-micro-py", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2a/ae/bb56c6828e4797ba5a4821eec7c43b8bf40f69cda4d4f5f8c8a2810ec96a/linkify-it-py-2.0.3.tar.gz", hash = "sha256:68cda27e162e9215c17d786649d1da0021a451bdc436ef9e0fa0ba5234b9b048", size = 27946, upload-time = "2024-02-04T14:48:04.179Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/1e/b832de447dee8b582cac175871d2f6c3d5077cc56d5575cadba1fd1cccfa/linkify_it_py-2.0.3-py3-none-any.whl", hash = "sha256:6bcbc417b0ac14323382aef5c5192c0075bf8a9d6b41820a2b66371eac6b6d79", size = 19820, upload-time = "2024-02-04T14:48:02.496Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "loguru"
|
||||
version = "0.7.3"
|
||||
@@ -588,6 +602,14 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
linkify = [
|
||||
{ name = "linkify-it-py", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
plugins = [
|
||||
{ name = "mdit-py-plugins", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markupsafe"
|
||||
version = "3.0.2"
|
||||
@@ -612,6 +634,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/80/0985960e4b89922cb5a0bac0ed39c5b96cbc1a536a99f30e8c220a996ed9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9", size = 24098, upload-time = "2024-10-18T15:21:40.813Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdit-py-plugins"
|
||||
version = "0.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "markdown-it-py", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b2/fd/a756d36c0bfba5f6e39a1cdbdbfdd448dc02692467d83816dff4592a1ebc/mdit_py_plugins-0.5.0.tar.gz", hash = "sha256:f4918cb50119f50446560513a8e311d574ff6aaed72606ddae6d35716fe809c6", size = 44655, upload-time = "2025-08-11T07:25:49.083Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/86/dd6e5db36df29e76c7a7699123569a4a18c1623ce68d826ed96c62643cae/mdit_py_plugins-0.5.0-py3-none-any.whl", hash = "sha256:07a08422fc1936a5d26d146759e9155ea466e842f5ab2f7d2266dd084c8dab1f", size = 57205, upload-time = "2025-08-11T07:25:47.597Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
@@ -774,6 +808,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/78/f9/690a8600b93c332de3ab4a344a4ac34f00c8f104917061f779db6a918ed6/pathlib-1.0.1-py3-none-any.whl", hash = "sha256:f35f95ab8b0f59e6d354090350b44a80a80635d22efdedfa84c7ad1cf0a74147", size = 14363, upload-time = "2022-05-04T13:37:20.585Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "platformdirs"
|
||||
version = "4.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/23/e8/21db9c9987b0e728855bd57bff6984f67952bea55d6f75e055c46b5383e8/platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf", size = 21634, upload-time = "2025-08-26T14:32:04.268Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/40/4b/2028861e724d3bd36227adfa20d3fd24c3fc6d52032f4a93c133be5d17ce/platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85", size = 18654, upload-time = "2025-08-26T14:32:02.735Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
@@ -1122,6 +1165,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/fd/901cfa59aaa5b30a99e16876f11abe38b59a1a2c51ffb3d7142bb6089069/starlette-0.47.3-py3-none-any.whl", hash = "sha256:89c0778ca62a76b826101e7c709e70680a1699ca7da6b44d38eb0a7e61fe4b51", size = 72991, upload-time = "2025-08-24T13:36:40.887Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textual"
|
||||
version = "5.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "markdown-it-py", extra = ["linkify", "plugins"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "platformdirs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pygments", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ba/ce/f0f938d33d9bebbf8629e0020be00c560ddfa90a23ebe727c2e5aa3f30cf/textual-5.3.0.tar.gz", hash = "sha256:1b6128b339adef2e298cc23ab4777180443240ece5c232f29b22960efd658d4d", size = 1557651, upload-time = "2025-08-07T12:36:50.342Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/00/2f/f7c8a533bee50fbf5bb37ffc1621e7b2cdd8c9a6301fc51faa35fa50b09d/textual-5.3.0-py3-none-any.whl", hash = "sha256:02a6abc065514c4e21f94e79aaecea1f78a28a85d11d7bfc64abf3392d399890", size = 702671, upload-time = "2025-08-07T12:36:48.272Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.21.4"
|
||||
@@ -1217,6 +1276,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uc-micro-py"
|
||||
version = "1.0.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/91/7a/146a99696aee0609e3712f2b44c6274566bc368dfe8375191278045186b8/uc-micro-py-1.0.3.tar.gz", hash = "sha256:d321b92cff673ec58027c04015fcaa8bb1e005478643ff4a500882eaab88c48a", size = 6043, upload-time = "2024-02-09T16:52:01.654Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/37/87/1f677586e8ac487e29672e4b17455758fce261de06a0d086167bb760361a/uc_micro_py-1.0.3-py3-none-any.whl", hash = "sha256:db1dffff340817673d7b466ec86114a9dc0e9d4d9b5ba229d9d60e5c12600cd5", size = 6229, upload-time = "2024-02-09T16:52:00.371Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.5.0"
|
||||
|
||||
Reference in New Issue
Block a user