Mirror of https://github.com/exo-explore/exo.git, synced 2025-12-23 22:27:50 -05:00

Worker Exception & Timeout Refactor

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Seth Howes <sethshowes@gmail.com>
23  .github/workflows/build-macos-app.yml  (vendored)
@@ -3,13 +3,14 @@ name: Build and Release Exo macOS App
on:
  push:
    tags:
      - 'v*' # Trigger only on version tags
      - 'v*' # Trigger on version tags
    branches:
      - main # Also build on main branch for testing
      - app-staging # Add app-staging for testing
  pull_request:
    branches:
      - main # Test builds on PRs
      - staging # Test builds on PRs to staging
      - main # Build on PRs to main

jobs:
  build-exov2-macos:
@@ -20,18 +21,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Setup Rust (nightly)
        uses: actions-rust-lang/setup-rust-toolchain@v1
        with:
          toolchain: nightly
          components: rustfmt, clippy
          default: true

      - name: Set Rust toolchain override
        run: |
          rustup default nightly
          cd rust && rustup override set nightly

      - name: Install Go
        uses: actions/setup-go@v5
        with:
@@ -52,12 +41,6 @@ jobs:
          uv python install
          uv sync --locked --all-extras

      - name: Build Rust Components
        env:
          RUSTFLAGS: "-A unused-imports -A dead-code -A unreachable-code"
        run: |
          just build-all

      - name: Install Python Bindings
        run: |
          uv pip install dist/exo_pyo3_bindings-*.whl

43  configure_mlx.sh  Normal file
@@ -0,0 +1,43 @@
#!/usr/bin/env bash

# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))

# Calculate 80% and TOTAL_MEM_GB-5GB in MB
EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
MINUS_5GB=$((($TOTAL_MEM_MB - 5120)))

# Calculate 70% and TOTAL_MEM_GB-8GB in MB
SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
MINUS_8GB=$((($TOTAL_MEM_MB - 8192)))

# Set WIRED_LIMIT_MB to higher value
if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
  WIRED_LIMIT_MB=$EIGHTY_PERCENT
else
  WIRED_LIMIT_MB=$MINUS_5GB
fi

# Set WIRED_LWM_MB to higher value
if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
  WIRED_LWM_MB=$SEVENTY_PERCENT
else
  WIRED_LWM_MB=$MINUS_8GB
fi

# Display the calculated values
echo "Total memory: $TOTAL_MEM_MB MB"
echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"

# Apply the values with sysctl, but check if we're already root
if [ "$EUID" -eq 0 ]; then
  sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
  sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
else
  # Try without sudo first, fall back to sudo if needed
  sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
    sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
  sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
    sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
fi
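Reviewer note: the limit arithmetic above can be sanity-checked with a small Python sketch (illustrative only; it mirrors the shell math and does not touch sysctl):

# Sketch of the wired-limit arithmetic from configure_mlx.sh (values in MB).
def wired_limits(total_mem_mb: int) -> tuple[int, int]:
    # Upper limit: the larger of 80% of RAM or RAM minus 5 GiB.
    wired_limit_mb = max(total_mem_mb * 80 // 100, total_mem_mb - 5120)
    # Low-water mark: the larger of 70% of RAM or RAM minus 8 GiB.
    wired_lwm_mb = max(total_mem_mb * 70 // 100, total_mem_mb - 8192)
    return wired_limit_mb, wired_lwm_mb

print(wired_limits(131072))  # 128 GiB machine -> (125952, 122880)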
@@ -483,25 +483,89 @@
|
||||
}
|
||||
|
||||
.model-select {
|
||||
background-color: var(--exo-medium-gray);
|
||||
background: linear-gradient(135deg, #2a2a2a 0%, #3c3c3c 50%, #2a2a2a 100%);
|
||||
color: var(--exo-light-gray);
|
||||
border: 1px solid var(--exo-light-gray);
|
||||
border-radius: 6px;
|
||||
padding: 10px 12px;
|
||||
font-size: 14px;
|
||||
border: 2px solid rgba(255, 215, 0, 0.2);
|
||||
border-radius: 12px;
|
||||
padding: 14px 20px 14px 16px;
|
||||
font-size: 15px;
|
||||
font-family: var(--font-family);
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.25s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
box-shadow:
|
||||
0 4px 12px rgba(0, 0, 0, 0.25),
|
||||
inset 0 1px 0 rgba(255, 255, 255, 0.12),
|
||||
inset 0 -1px 0 rgba(0, 0, 0, 0.1);
|
||||
position: relative;
|
||||
appearance: none;
|
||||
width: 100%;
|
||||
min-height: 48px;
|
||||
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23FFD700' d='M6 8.5L2.5 5h7z'/%3E%3C/svg%3E");
|
||||
background-position: calc(100% - 16px) center;
|
||||
background-size: 12px 12px;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
|
||||
.model-select:hover {
|
||||
background: linear-gradient(135deg, #363636 0%, #484848 50%, #363636 100%);
|
||||
border-color: rgba(255, 215, 0, 0.5);
|
||||
box-shadow:
|
||||
0 6px 20px rgba(0, 0, 0, 0.3),
|
||||
inset 0 1px 0 rgba(255, 255, 255, 0.15),
|
||||
inset 0 -1px 0 rgba(0, 0, 0, 0.1),
|
||||
0 0 0 1px rgba(255, 215, 0, 0.1);
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.model-select:focus {
|
||||
outline: none;
|
||||
border-color: var(--exo-yellow);
|
||||
box-shadow: 0 0 0 2px rgba(255, 215, 0, 0.2);
|
||||
box-shadow:
|
||||
0 0 0 4px rgba(255, 215, 0, 0.25),
|
||||
0 8px 24px rgba(0, 0, 0, 0.4),
|
||||
inset 0 1px 0 rgba(255, 255, 255, 0.2),
|
||||
inset 0 -1px 0 rgba(0, 0, 0, 0.1);
|
||||
background: linear-gradient(135deg, #404040 0%, #525252 50%, #404040 100%);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.model-select:active {
|
||||
transform: translateY(0);
|
||||
box-shadow:
|
||||
0 2px 8px rgba(0, 0, 0, 0.3),
|
||||
inset 0 1px 0 rgba(255, 255, 255, 0.1),
|
||||
inset 0 2px 6px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.model-select:disabled {
|
||||
background: linear-gradient(135deg, #1a1a1a 0%, #222222 100%);
|
||||
color: #555555;
|
||||
border-color: #333333;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
box-shadow: inset 0 2px 6px rgba(0, 0, 0, 0.4);
|
||||
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23555555' d='M6 8.5L2.5 5h7z'/%3E%3C/svg%3E");
|
||||
}
|
||||
|
||||
.model-select option {
|
||||
background-color: var(--exo-medium-gray);
|
||||
background-color: var(--exo-dark-gray);
|
||||
color: var(--exo-light-gray);
|
||||
padding: 12px 16px;
|
||||
border: none;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.model-select option:hover {
|
||||
background-color: var(--exo-medium-gray);
|
||||
color: var(--exo-yellow);
|
||||
}
|
||||
|
||||
.model-select option:checked {
|
||||
background-color: var(--exo-yellow);
|
||||
color: var(--exo-black);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.launch-button {
|
||||
@@ -576,6 +640,7 @@
|
||||
color: var(--exo-light-gray);
|
||||
font-style: italic;
|
||||
margin-top: 40px;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
|
||||
@@ -588,6 +653,13 @@
|
||||
|
||||
<!-- Sidebar -->
|
||||
<div class="sidebar" id="instancesSidebar">
|
||||
<div class="sidebar-header">
|
||||
<h3>Running Instances</h3>
|
||||
</div>
|
||||
<div class="sidebar-content" id="instancesList">
|
||||
<div class="no-instances">Loading instances...</div>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-header">
|
||||
<h3>Launch Instance</h3>
|
||||
</div>
|
||||
@@ -601,13 +673,6 @@
|
||||
<div id="launchStatus" class="launch-status"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-header">
|
||||
<h3>Running Instances</h3>
|
||||
</div>
|
||||
<div class="sidebar-content" id="instancesList">
|
||||
<div class="no-instances">Loading instances...</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="dashboard-header">
|
||||
|
||||
@@ -29,7 +29,6 @@ def mx_barrier():
        )
    )


class HostList(RootModel[list[str]]):
    @classmethod
    def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -130,3 +129,18 @@ async def apply_chat_template(
    )

    return prompt


def mlx_force_oom(size: int = 40000) -> None:
    """
    Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
    """
    mx.set_default_device(mx.gpu)  # type: ignore
    a = mx.random.uniform(shape=(size, size), dtype=mx.float32)  # type: ignore
    b = mx.random.uniform(shape=(size, size), dtype=mx.float32)  # type: ignore
    mx.eval(a, b)  # type: ignore
    c = mx.matmul(a, b)  # type: ignore
    d = mx.matmul(a, c)  # type: ignore
    e = mx.matmul(b, c)  # type: ignore
    f = mx.sigmoid(d + e)  # type: ignore
    mx.eval(f)  # type: ignore

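Reviewer note: a minimal sketch of how mlx_force_oom might be exercised in a worker exception test. The module path and the broad exception target are assumptions, not part of this diff; MLX's actual allocation error may be a more specific type.

import pytest

def test_mlx_force_oom_surfaces_an_exception():
    # Assumed import location; adjust to wherever mlx_force_oom lives.
    from shared.mlx_utils import mlx_force_oom
    # The worker's timeout/exception handling should see this raise rather than hang.
    with pytest.raises(Exception):
        mlx_force_oom(size=40000)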
@@ -32,6 +32,7 @@ from shared.types.events.commands import (
    CommandType,
    CreateInstanceCommand,
    DeleteInstanceCommand,
    TaskFinishedCommand,
)
from shared.types.events.components import EventFromEventLog
from shared.types.models import ModelMetadata
@@ -177,6 +178,11 @@ class API:
            if event.chunk.finish_reason is not None:
                yield "data: [DONE]"
                finished = True

                command = TaskFinishedCommand(
                    command_id=command_id
                )
                self.command_buffer.append(command)

                return

@@ -14,11 +14,14 @@ from shared.apply import apply
|
||||
from shared.db.sqlite.config import EventLogConfig
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogManager
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.common import CommandId, NodeId
|
||||
from shared.types.events import (
|
||||
Event,
|
||||
Heartbeat,
|
||||
InstanceDeleted,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TopologyEdgeDeleted,
|
||||
TopologyNodeCreated,
|
||||
)
|
||||
from shared.types.events.commands import (
|
||||
@@ -26,6 +29,7 @@ from shared.types.events.commands import (
|
||||
Command,
|
||||
CreateInstanceCommand,
|
||||
DeleteInstanceCommand,
|
||||
TaskFinishedCommand,
|
||||
)
|
||||
from shared.types.state import State
|
||||
from shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
|
||||
@@ -43,6 +47,7 @@ class Master:
|
||||
self.command_buffer = command_buffer
|
||||
self.global_events = global_events
|
||||
self.worker_events = worker_events
|
||||
self.command_task_mapping: dict[CommandId, TaskId] = {}
|
||||
self.forwarder_supervisor = ForwarderSupervisor(
|
||||
self.node_id,
|
||||
forwarder_binary_path=forwarder_binary_path,
|
||||
@@ -96,6 +101,8 @@ class Master:
|
||||
task_params=next_command.request_params
|
||||
)
|
||||
))
|
||||
|
||||
self.command_task_mapping[next_command.command_id] = task_id
|
||||
case DeleteInstanceCommand():
|
||||
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
|
||||
transition_events = get_transition_events(self.state.instances, placement)
|
||||
@@ -104,6 +111,11 @@ class Master:
|
||||
placement = get_instance_placements(next_command, self.state.topology, self.state.instances)
|
||||
transition_events = get_transition_events(self.state.instances, placement)
|
||||
next_events.extend(transition_events)
|
||||
case TaskFinishedCommand():
|
||||
next_events.append(TaskDeleted(
|
||||
task_id=self.command_task_mapping[next_command.command_id]
|
||||
))
|
||||
del self.command_task_mapping[next_command.command_id]
|
||||
|
||||
await self.event_log_for_writes.append_events(next_events, origin=self.node_id)
|
||||
# 2. get latest events
|
||||
@@ -119,6 +131,24 @@ class Master:
|
||||
self.state = apply(self.state, event_from_log)
|
||||
self.logger.info(f"state: {self.state.model_dump_json()}")
|
||||
|
||||
# TODO: This can be done in a better place. But for now, we use this to check if any running instances have been broken.
|
||||
write_events: list[Event] = []
|
||||
if any([isinstance(event_from_log.event, TopologyEdgeDeleted) for event_from_log in events]):
|
||||
connected_node_ids = set([x.node_id for x in self.state.topology.list_nodes()])
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
delete = False
|
||||
for node_id in instance.shard_assignments.node_to_runner:
|
||||
if node_id not in connected_node_ids:
|
||||
delete = True
|
||||
break
|
||||
if delete:
|
||||
write_events.append(InstanceDeleted(
|
||||
instance_id=instance_id
|
||||
))
|
||||
|
||||
if write_events:
|
||||
await self.event_log_for_writes.append_events(events=write_events, origin=self.node_id)
|
||||
|
||||
async def run(self):
|
||||
self.state = await self._get_state_snapshot()
|
||||
|
||||
|
||||
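Reviewer note: a condensed sketch of the command-to-task bookkeeping the Master gains above (simplified; plain strings stand in for CommandId and TaskId, and event emission is elided):

# ChatCompletionCommand -> remember which task it spawned;
# TaskFinishedCommand   -> pop the mapping and delete that task.
command_task_mapping: dict[str, str] = {}

def on_chat_completion(command_id: str, task_id: str) -> None:
    command_task_mapping[command_id] = task_id

def on_task_finished(command_id: str) -> str:
    # The returned task id is what the Master turns into a TaskDeleted event.
    return command_task_mapping.pop(command_id)

on_chat_completion("cmd-1", "task-1")
assert on_task_finished("cmd-1") == "task-1"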
@@ -41,7 +41,14 @@ def get_instance_placements(
        raise ValueError("No cycles found with sufficient memory")

    smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
    selected_cycle = max(smallest_cycles, key=lambda cycle: sum(node.node_profile.memory.ram_available for node in cycle if node.node_profile is not None))
    selected_cycle = None
    for cycle in smallest_cycles:
        cycle_graph: Topology = topology.get_subgraph_from_nodes(cycle)
        if cycle_graph.is_thunderbolt_cycle(cycle):
            selected_cycle = cycle
            break
    if selected_cycle is None:
        selected_cycle = max(smallest_cycles, key=lambda cycle: sum(node.node_profile.memory.ram_available for node in cycle if node.node_profile is not None))

    shard_assignments = get_shard_assignments(command.model_meta, selected_cycle)

@@ -83,6 +83,10 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
    if not cycles:
        return []

    get_thunderbolt = False
    if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
        get_thunderbolt = True

    cycle = cycles[0]
    hosts: list[Host] = []
    for i in range(len(cycle)):
@@ -92,8 +96,10 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
        for connection in cycle_digraph.list_connections():
            if (connection.local_node_id == current_node.node_id and
                    connection.send_back_node_id == next_node.node_id):
                if get_thunderbolt and not connection.is_thunderbolt():
                    continue
                host = Host(
                    ip=connection.send_back_multiaddr.ipv4_address,
                    ip=connection.send_back_multiaddr.ip_address,
                    port=connection.send_back_multiaddr.port
                )
                hosts.append(host)

@@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
@@ -18,6 +19,10 @@ var (
|
||||
eventsDBPath string
|
||||
eventsDB *sql.DB
|
||||
eventsDBMu sync.Mutex
|
||||
|
||||
// Track connections to prevent duplicate events
|
||||
connectionTracker = make(map[string]bool)
|
||||
connTrackerMu sync.Mutex
|
||||
)
|
||||
|
||||
// SetEventsDBPath sets the path to the events database
|
||||
@@ -166,33 +171,44 @@ func (n *NotifeeHandler) Connected(net network.Network, conn network.Conn) {
|
||||
localAddr := conn.LocalMultiaddr()
|
||||
remoteAddr := conn.RemoteMultiaddr()
|
||||
|
||||
// Get the actual node IDs (not peer IDs)
|
||||
// Check if we've already processed this connection
|
||||
connKey := fmt.Sprintf("%s-%s", conn.LocalPeer(), remotePeer)
|
||||
connTrackerMu.Lock()
|
||||
if connectionTracker[connKey] {
|
||||
connTrackerMu.Unlock()
|
||||
log.Printf("Skipping duplicate connection event for %s", remotePeer)
|
||||
return
|
||||
}
|
||||
connectionTracker[connKey] = true
|
||||
connTrackerMu.Unlock()
|
||||
|
||||
// Get the local node ID
|
||||
localNodeID := GetNodeId()
|
||||
|
||||
// For remote node, we need to extract from peer ID or use a mapping
|
||||
// For now, we'll use the peer ID as a placeholder
|
||||
// TODO: Implement proper node ID mapping/discovery
|
||||
remoteNodeID := remotePeer.String()
|
||||
|
||||
// Create connection event
|
||||
event := &TopologyEdgeCreated{
|
||||
EventType: EventTypeTopologyEdgeCreated,
|
||||
EventID: uuid.New().String(),
|
||||
Edge: Connection{
|
||||
LocalNodeID: localNodeID,
|
||||
SendBackNodeID: remoteNodeID,
|
||||
LocalMultiaddr: parseMultiaddr(localAddr),
|
||||
SendBackMultiaddr: parseMultiaddr(remoteAddr),
|
||||
ConnectionProfile: nil, // TODO: Add connection profiling if needed
|
||||
},
|
||||
}
|
||||
|
||||
// Write event to database
|
||||
if err := writeEvent(EventTypeTopologyEdgeCreated, event); err != nil {
|
||||
log.Printf("Failed to write edge created event: %v", err)
|
||||
} else {
|
||||
log.Printf("Wrote edge created event: %s -> %s", localNodeID, remoteNodeID)
|
||||
}
|
||||
// Asynchronously exchange node IDs and write event
|
||||
go func() {
|
||||
mapper := GetNodeIDMapper()
|
||||
|
||||
// Add a small delay to ensure both sides are ready
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Exchange node IDs
|
||||
if err := mapper.ExchangeNodeID(remotePeer); err != nil {
|
||||
log.Printf("Failed to exchange node ID with %s: %v", remotePeer, err)
|
||||
// Don't write event if we can't get the node ID
|
||||
return
|
||||
}
|
||||
|
||||
// Get the actual remote node ID
|
||||
remoteNodeID, ok := mapper.GetNodeIDForPeer(remotePeer)
|
||||
if !ok {
|
||||
log.Printf("Node ID not found for peer %s after successful exchange", remotePeer)
|
||||
return
|
||||
}
|
||||
|
||||
// Write edge created event with correct node IDs
|
||||
writeEdgeCreatedEvent(localNodeID, remoteNodeID, localAddr, remoteAddr)
|
||||
}()
|
||||
}
|
||||
|
||||
// Disconnected is called when a connection is closed
|
||||
@@ -201,9 +217,27 @@ func (n *NotifeeHandler) Disconnected(net network.Network, conn network.Conn) {
|
||||
localAddr := conn.LocalMultiaddr()
|
||||
remoteAddr := conn.RemoteMultiaddr()
|
||||
|
||||
// Clear connection tracker
|
||||
connKey := fmt.Sprintf("%s-%s", conn.LocalPeer(), remotePeer)
|
||||
connTrackerMu.Lock()
|
||||
delete(connectionTracker, connKey)
|
||||
connTrackerMu.Unlock()
|
||||
|
||||
// Get the actual node IDs (not peer IDs)
|
||||
localNodeID := GetNodeId()
|
||||
remoteNodeID := remotePeer.String() // TODO: Implement proper node ID mapping
|
||||
|
||||
// Get the remote node ID from the mapper
|
||||
mapper := GetNodeIDMapper()
|
||||
remoteNodeID, ok := mapper.GetNodeIDForPeer(remotePeer)
|
||||
if !ok {
|
||||
// Don't write event if we don't have the node ID mapping
|
||||
log.Printf("No node ID mapping found for disconnected peer %s, skipping event", remotePeer)
|
||||
mapper.RemoveMapping(remotePeer)
|
||||
return
|
||||
}
|
||||
|
||||
// Clean up the mapping
|
||||
mapper.RemoveMapping(remotePeer)
|
||||
|
||||
// Create disconnection event
|
||||
event := &TopologyEdgeDeleted{
|
||||
@@ -253,6 +287,27 @@ func parseMultiaddr(ma multiaddr.Multiaddr) Multiaddr {
|
||||
return result
|
||||
}
|
||||
|
||||
// writeEdgeCreatedEvent writes a topology edge created event
|
||||
func writeEdgeCreatedEvent(localNodeID, remoteNodeID string, localAddr, remoteAddr multiaddr.Multiaddr) {
|
||||
event := &TopologyEdgeCreated{
|
||||
EventType: EventTypeTopologyEdgeCreated,
|
||||
EventID: uuid.New().String(),
|
||||
Edge: Connection{
|
||||
LocalNodeID: localNodeID,
|
||||
SendBackNodeID: remoteNodeID,
|
||||
LocalMultiaddr: parseMultiaddr(localAddr),
|
||||
SendBackMultiaddr: parseMultiaddr(remoteAddr),
|
||||
ConnectionProfile: nil,
|
||||
},
|
||||
}
|
||||
|
||||
if err := writeEvent(EventTypeTopologyEdgeCreated, event); err != nil {
|
||||
log.Printf("Failed to write edge created event: %v", err)
|
||||
} else {
|
||||
log.Printf("Wrote edge created event: %s -> %s", localNodeID, remoteNodeID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetNotifee returns a singleton instance of the notifee handler
|
||||
func GetNotifee() network.Notifiee {
|
||||
return &NotifeeHandler{}
|
||||
|
||||
@@ -433,6 +433,9 @@ func getNode(ctx context.Context) {
|
||||
|
||||
// Register event notifiee to track topology changes
|
||||
node.Network().Notify(GetNotifee())
|
||||
|
||||
// Set up node ID mapper
|
||||
GetNodeIDMapper().SetHost(node)
|
||||
|
||||
// Start a goroutine to periodically trigger mDNS discovery
|
||||
go periodicMDNSDiscovery()
|
||||
|
||||
185  networking/forwarder/src/node_id_exchange.go  Normal file
@@ -0,0 +1,185 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
// NodeIDExchangeProtocol is the protocol ID for node ID exchange
|
||||
NodeIDExchangeProtocol = "/forwarder/nodeid/1.0.0"
|
||||
|
||||
// Exchange timeout - balanced for reliability
|
||||
exchangeTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// NodeIDMessage is the message format for node ID exchange
|
||||
type NodeIDMessage struct {
|
||||
NodeID string `json:"node_id"`
|
||||
}
|
||||
|
||||
// NodeIDMapper manages the mapping between peer IDs and node IDs
|
||||
type NodeIDMapper struct {
|
||||
mu sync.RWMutex
|
||||
peerToNode map[peer.ID]string
|
||||
nodeToPeer map[string]peer.ID
|
||||
host host.Host
|
||||
}
|
||||
|
||||
var (
|
||||
nodeIDMapper *NodeIDMapper
|
||||
mapperOnce sync.Once
|
||||
)
|
||||
|
||||
// GetNodeIDMapper returns the singleton NodeIDMapper instance
|
||||
func GetNodeIDMapper() *NodeIDMapper {
|
||||
mapperOnce.Do(func() {
|
||||
nodeIDMapper = &NodeIDMapper{
|
||||
peerToNode: make(map[peer.ID]string),
|
||||
nodeToPeer: make(map[string]peer.ID),
|
||||
}
|
||||
})
|
||||
return nodeIDMapper
|
||||
}
|
||||
|
||||
// SetHost sets the libp2p host for the mapper
|
||||
func (m *NodeIDMapper) SetHost(h host.Host) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.host = h
|
||||
|
||||
// Set up the stream handler for incoming node ID exchanges
|
||||
h.SetStreamHandler(NodeIDExchangeProtocol, m.handleNodeIDStream)
|
||||
}
|
||||
|
||||
// GetNodeIDForPeer returns the node ID for a given peer ID
|
||||
func (m *NodeIDMapper) GetNodeIDForPeer(peerID peer.ID) (string, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
nodeID, ok := m.peerToNode[peerID]
|
||||
return nodeID, ok
|
||||
}
|
||||
|
||||
// GetPeerIDForNode returns the peer ID for a given node ID
|
||||
func (m *NodeIDMapper) GetPeerIDForNode(nodeID string) (peer.ID, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
peerID, ok := m.nodeToPeer[nodeID]
|
||||
return peerID, ok
|
||||
}
|
||||
|
||||
// SetMapping sets the mapping between a peer ID and node ID
|
||||
func (m *NodeIDMapper) SetMapping(peerID peer.ID, nodeID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.peerToNode[peerID] = nodeID
|
||||
m.nodeToPeer[nodeID] = peerID
|
||||
log.Printf("Mapped peer %s to node %s", peerID, nodeID)
|
||||
}
|
||||
|
||||
// RemoveMapping removes the mapping for a peer
|
||||
func (m *NodeIDMapper) RemoveMapping(peerID peer.ID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if nodeID, ok := m.peerToNode[peerID]; ok {
|
||||
delete(m.peerToNode, peerID)
|
||||
delete(m.nodeToPeer, nodeID)
|
||||
log.Printf("Removed mapping for peer %s (was node %s)", peerID, nodeID)
|
||||
}
|
||||
}
|
||||
|
||||
// ExchangeNodeID initiates a node ID exchange with a peer
|
||||
func (m *NodeIDMapper) ExchangeNodeID(peerID peer.ID) error {
|
||||
if m.host == nil {
|
||||
return fmt.Errorf("host not set")
|
||||
}
|
||||
|
||||
// Check if we already have the mapping
|
||||
if _, ok := m.GetNodeIDForPeer(peerID); ok {
|
||||
return nil // Already have the mapping
|
||||
}
|
||||
|
||||
// Try up to 3 times with exponential backoff
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Exponential backoff: 100ms, 200ms, 400ms
|
||||
time.Sleep(time.Duration(100<<uint(attempt-1)) * time.Millisecond)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), exchangeTimeout)
|
||||
|
||||
// Open a stream to the peer
|
||||
stream, err := m.host.NewStream(ctx, peerID, NodeIDExchangeProtocol)
|
||||
if err != nil {
|
||||
cancel()
|
||||
lastErr = fmt.Errorf("failed to open stream: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Send our node ID
|
||||
msg := NodeIDMessage{NodeID: GetNodeId()}
|
||||
encoder := json.NewEncoder(stream)
|
||||
if err := encoder.Encode(&msg); err != nil {
|
||||
stream.Close()
|
||||
cancel()
|
||||
lastErr = fmt.Errorf("failed to send node ID: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Read their node ID
|
||||
decoder := json.NewDecoder(bufio.NewReader(stream))
|
||||
var response NodeIDMessage
|
||||
if err := decoder.Decode(&response); err != nil {
|
||||
stream.Close()
|
||||
cancel()
|
||||
lastErr = fmt.Errorf("failed to read node ID: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stream.Close()
|
||||
cancel()
|
||||
|
||||
// Store the mapping
|
||||
m.SetMapping(peerID, response.NodeID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// handleNodeIDStream handles incoming node ID exchange requests
|
||||
func (m *NodeIDMapper) handleNodeIDStream(stream network.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
peerID := stream.Conn().RemotePeer()
|
||||
|
||||
// Read their node ID
|
||||
decoder := json.NewDecoder(bufio.NewReader(stream))
|
||||
var msg NodeIDMessage
|
||||
if err := decoder.Decode(&msg); err != nil {
|
||||
log.Printf("Failed to read node ID from %s: %v", peerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Store the mapping
|
||||
m.SetMapping(peerID, msg.NodeID)
|
||||
|
||||
// Send our node ID back
|
||||
response := NodeIDMessage{NodeID: GetNodeId()}
|
||||
encoder := json.NewEncoder(stream)
|
||||
if err := encoder.Encode(&response); err != nil {
|
||||
log.Printf("Failed to send node ID to %s: %v", peerID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
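Reviewer note: the retry schedule in ExchangeNodeID above, expressed as a Python sketch of the 100<<(attempt-1) millisecond arithmetic (no sleep precedes the first attempt):

def backoff_delays_ms(attempts: int = 3) -> list[int]:
    # Delays applied before each retry: 100 ms, then 200 ms.
    return [100 << (attempt - 1) for attempt in range(1, attempts)]

print(backoff_delays_ms())  # [100, 200]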
111  networking/forwarder/src/node_id_exchange_test.go  Normal file
@@ -0,0 +1,111 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockNodeIDStreamHandler creates a stream handler that responds with a specific node ID
|
||||
func mockNodeIDStreamHandler(nodeID string) func(stream network.Stream) {
|
||||
return func(stream network.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
peerID := stream.Conn().RemotePeer()
|
||||
|
||||
// Read their node ID
|
||||
decoder := json.NewDecoder(bufio.NewReader(stream))
|
||||
var msg NodeIDMessage
|
||||
if err := decoder.Decode(&msg); err != nil {
|
||||
log.Printf("Failed to read node ID from %s: %v", peerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Send our node ID back
|
||||
response := NodeIDMessage{NodeID: nodeID}
|
||||
encoder := json.NewEncoder(stream)
|
||||
if err := encoder.Encode(&response); err != nil {
|
||||
log.Printf("Failed to send node ID to %s: %v", peerID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeIDExchange(t *testing.T) {
|
||||
// Create two test hosts
|
||||
h1, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
|
||||
require.NoError(t, err)
|
||||
defer h1.Close()
|
||||
|
||||
h2, err := libp2p.New(libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
|
||||
require.NoError(t, err)
|
||||
defer h2.Close()
|
||||
|
||||
// Set up node ID for host 1
|
||||
SetNodeId("node-1")
|
||||
mapper1 := GetNodeIDMapper()
|
||||
mapper1.SetHost(h1)
|
||||
|
||||
// Set up host 2 with a mock handler that responds with "node-2"
|
||||
h2.SetStreamHandler(NodeIDExchangeProtocol, mockNodeIDStreamHandler("node-2"))
|
||||
|
||||
// Connect the hosts
|
||||
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), 3600)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exchange node IDs
|
||||
err = mapper1.ExchangeNodeID(h2.ID())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the mapping on host 1
|
||||
nodeID, ok := mapper1.GetNodeIDForPeer(h2.ID())
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "node-2", nodeID)
|
||||
}
|
||||
|
||||
func TestNodeIDMapperOperations(t *testing.T) {
|
||||
mapper := &NodeIDMapper{
|
||||
peerToNode: make(map[peer.ID]string),
|
||||
nodeToPeer: make(map[string]peer.ID),
|
||||
}
|
||||
|
||||
// Test peer ID (simulated)
|
||||
peerID := peer.ID("test-peer-id")
|
||||
nodeID := "test-node-id"
|
||||
|
||||
// Set mapping
|
||||
mapper.SetMapping(peerID, nodeID)
|
||||
|
||||
// Verify forward mapping
|
||||
gotNodeID, ok := mapper.GetNodeIDForPeer(peerID)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, nodeID, gotNodeID)
|
||||
|
||||
// Verify reverse mapping
|
||||
gotPeerID, ok := mapper.GetPeerIDForNode(nodeID)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, peerID, gotPeerID)
|
||||
|
||||
// Remove mapping
|
||||
mapper.RemoveMapping(peerID)
|
||||
|
||||
// Verify removal
|
||||
_, ok = mapper.GetNodeIDForPeer(peerID)
|
||||
assert.False(t, ok)
|
||||
|
||||
_, ok = mapper.GetPeerIDForNode(nodeID)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
@@ -39,7 +39,8 @@ members = [
    "master",
    "worker",
    "shared",
    "engines/*"
    "engines/*",
    "scripts"
]

[tool.uv.sources]

@@ -1,25 +0,0 @@
import asyncio
from logging import Logger


from worker.main import get_node_id
from shared.types.common import NodeId
from shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig

async def main():
    node_id: NodeId = get_node_id()
    logger: Logger = Logger('worker_log')

    event_log_manager: EventLogManager = EventLogManager(EventLogConfig(), logger)
    await event_log_manager.initialize()

    events = await event_log_manager.global_events.get_events_since(0)

    for wrapped_event in events:
        event = wrapped_event.event
        event_type = type(event).__name__.replace('_', ' ').title()
        attributes = ', '.join(f"{key}={value!r}" for key, value in vars(event).items())
        print(f"{event_type}: {attributes}")

if __name__ == "__main__":
    asyncio.run(main())
0  scripts/README.md  Normal file
30  scripts/pyproject.toml  Normal file
@@ -0,0 +1,30 @@
[project]
name = "exo-scripts"
version = "0.1.0"
description = "Scripts for the Exo project"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "exo-shared",
    "huggingface_hub>=0.33.4",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build]
clean = true

[tool.hatch.build.targets.wheel]
packages = []
include = ["*"]
exclude = ["*.md", "pyproject.toml"]

[tool.hatch.build.targets.sdist]
packages = []
include = ["*"]
exclude = ["*.md", "pyproject.toml"]
516  scripts/read_events.py  Normal file
@@ -0,0 +1,516 @@
|
||||
import asyncio
|
||||
import curses
|
||||
import time
|
||||
import json
|
||||
import argparse
|
||||
import textwrap
|
||||
import sys
|
||||
from logging import Logger
|
||||
from typing import List, Optional, Any, Sequence, Tuple
|
||||
|
||||
from shared.types.state import State
|
||||
from shared.apply import apply
|
||||
from shared.db.sqlite.event_log_manager import EventLogManager, EventLogConfig
|
||||
from shared.types.events.components import EventFromEventLog
|
||||
from shared.types.events import Event
|
||||
|
||||
# Globals
|
||||
logger: Logger = Logger('helper_log')
|
||||
event_log_manager: Optional[EventLogManager] = None
|
||||
worker_mode: bool = False
|
||||
|
||||
# Worker-related event types
|
||||
WORKER_EVENT_TYPES = {
|
||||
'TaskCreated', 'TaskStateUpdated', 'TaskFailed', 'TaskDeleted',
|
||||
'ChunkGenerated',
|
||||
'InstanceCreated', 'InstanceDeleted', 'InstanceActivated', 'InstanceDeactivated', 'InstanceReplacedAtomically',
|
||||
'RunnerStatusUpdated', 'RunnerDeleted'
|
||||
}
|
||||
|
||||
async def init_db() -> None:
|
||||
global event_log_manager
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
|
||||
async def get_events_since(since: int) -> Sequence[EventFromEventLog[Event]]:
|
||||
return await event_log_manager.global_events.get_events_since(since) # type: ignore[attr-defined, return-value]
|
||||
|
||||
async def load_all_events() -> List[EventFromEventLog[Event]]:
|
||||
events: List[EventFromEventLog[Event]] = []
|
||||
since = 0
|
||||
while True:
|
||||
new_events = await get_events_since(since)
|
||||
if not new_events:
|
||||
break
|
||||
events.extend(new_events)
|
||||
since += len(new_events)
|
||||
return events
|
||||
|
||||
def compute_states(events: List[EventFromEventLog[Event]]) -> List[State]:
|
||||
states: List[State] = [State()]
|
||||
state = states[0]
|
||||
for event in events:
|
||||
state = apply(state, event)
|
||||
states.append(state)
|
||||
return states
|
||||
|
||||
def print_event(event: EventFromEventLog[Event]) -> None:
|
||||
event_type_name = type(event.event).__name__
|
||||
event_type = event_type_name.replace('_', ' ').title()
|
||||
attributes = ', '.join(f"{key}={value!r}" for key, value in vars(event.event).items())
|
||||
print(f"[{event.idx_in_log}] {event_type}: {attributes}")
|
||||
|
||||
async def non_tui_mode() -> None:
|
||||
await init_db()
|
||||
events = await load_all_events()
|
||||
states = compute_states(events)
|
||||
final_state = states[-1]
|
||||
|
||||
if worker_mode:
|
||||
filtered_events = [e for e in events if type(e.event).__name__ in WORKER_EVENT_TYPES]
|
||||
events = filtered_events
|
||||
# Recompute states? But states are cumulative, so perhaps just print filtered events and full state, or filter state too.
|
||||
state_dict = json.loads(final_state.model_dump_json())
|
||||
filtered_state = {
|
||||
'node_status': state_dict.get('node_status', {}),
|
||||
'instances': state_dict.get('instances', {}),
|
||||
'runners': state_dict.get('runners', {}),
|
||||
'tasks': state_dict.get('tasks', {}),
|
||||
'last_event_applied_idx': state_dict.get('last_event_applied_idx', 0)
|
||||
}
|
||||
print("Final State (filtered):")
|
||||
print(json.dumps(filtered_state, indent=2))
|
||||
else:
|
||||
print("Final State:")
|
||||
print(final_state.model_dump_json(indent=2))
|
||||
|
||||
print("\nEvents:")
|
||||
for event in events:
|
||||
print_event(event)
|
||||
|
||||
async def update_events(wrapped_events: List[EventFromEventLog[Event]], states: List[State], filtered_indices: Optional[List[int]] = None) -> bool:
|
||||
last_since = len(wrapped_events)
|
||||
new_wrapped = await get_events_since(last_since)
|
||||
if new_wrapped:
|
||||
last_len = len(wrapped_events)
|
||||
for nw in new_wrapped:
|
||||
state = states[-1]
|
||||
new_state = apply(state, nw)
|
||||
states.append(new_state)
|
||||
wrapped_events.extend(new_wrapped)
|
||||
if filtered_indices is not None:
|
||||
for k in range(last_len, len(wrapped_events)):
|
||||
if type(wrapped_events[k].event).__name__ in WORKER_EVENT_TYPES:
|
||||
filtered_indices.append(k)
|
||||
return True
|
||||
return False
|
||||
|
||||
def draw_state(win: Any, state: State, height: int, width: int, worker_mode: bool, state_scroll: int) -> int:
|
||||
win.clear()
|
||||
state_dict = json.loads(state.model_dump_json())
|
||||
if worker_mode:
|
||||
filtered_state = {
|
||||
'node_status': state_dict.get('node_status', {}),
|
||||
'instances': state_dict.get('instances', {}),
|
||||
'runners': state_dict.get('runners', {}),
|
||||
'tasks': state_dict.get('tasks', {}),
|
||||
'last_event_applied_idx': state_dict.get('last_event_applied_idx', 0)
|
||||
}
|
||||
state_pretty = json.dumps(filtered_state, indent=2)
|
||||
else:
|
||||
state_pretty = json.dumps(state_dict, indent=2)
|
||||
lines = state_pretty.split('\n')
|
||||
max_scroll = max(0, len(lines) - height)
|
||||
current_scroll = min(state_scroll, max_scroll)
|
||||
for i in range(height):
|
||||
line_idx = current_scroll + i
|
||||
if line_idx >= len(lines):
|
||||
break
|
||||
line = lines[line_idx]
|
||||
y = i
|
||||
x = 0
|
||||
leading_spaces = len(line) - len(line.lstrip())
|
||||
win.addstr(y, x, ' ' * leading_spaces)
|
||||
x += leading_spaces
|
||||
stripped = line.lstrip()
|
||||
if stripped.startswith('"'):
|
||||
end_key = stripped.find('": ')
|
||||
if end_key != -1:
|
||||
key_str = stripped[:end_key + 3] # include ":
|
||||
win.addstr(y, x, key_str, curses.color_pair(3))
|
||||
x += len(key_str)
|
||||
value_str = stripped[end_key + 3:]
|
||||
if value_str.startswith('"'):
|
||||
color = 2
|
||||
elif value_str.replace('.', '', 1).isdigit() or (value_str.startswith('-') and value_str[1:].replace('.', '', 1).isdigit()):
|
||||
color = 4
|
||||
elif value_str in ['true', 'false', 'null']:
|
||||
color = 5
|
||||
elif value_str.startswith('{') or value_str.startswith('[') or value_str.startswith('}') or value_str.startswith(']'):
|
||||
color = 0
|
||||
else:
|
||||
color = 0
|
||||
win.addstr(y, x, value_str, curses.color_pair(color))
|
||||
else:
|
||||
win.addstr(y, x, stripped)
|
||||
else:
|
||||
win.addstr(y, x, stripped)
|
||||
win.refresh()
|
||||
return current_scroll
|
||||
|
||||
def get_event_pairs(event: EventFromEventLog[Event]) -> List[Tuple[str, int]]:
|
||||
pairs: List[Tuple[str, int]] = []
|
||||
idx_str = f"[{event.idx_in_log}] "
|
||||
pairs.append((idx_str, 5))
|
||||
event_type_name = type(event.event).__name__
|
||||
event_type = event_type_name.replace('_', ' ').title()
|
||||
pairs.append((event_type, 1))
|
||||
pairs.append((": ", 0))
|
||||
attrs = vars(event.event)
|
||||
first = True
|
||||
for key, value in attrs.items():
|
||||
if not first:
|
||||
pairs.append((", ", 0))
|
||||
first = False
|
||||
pairs.append((key, 3))
|
||||
pairs.append(("=", 0))
|
||||
v_str = repr(value)
|
||||
if isinstance(value, str):
|
||||
color = 2
|
||||
elif isinstance(value, (int, float)):
|
||||
color = 4
|
||||
elif isinstance(value, bool):
|
||||
color = 5
|
||||
else:
|
||||
color = 6
|
||||
pairs.append((v_str, color))
|
||||
return pairs
|
||||
|
||||
def calculate_event_lines(pairs: List[Tuple[str, int]], win_width: int, subsequent_indent: int) -> int:
|
||||
lines = 1
|
||||
x = 0
|
||||
for text, _ in pairs:
|
||||
i = 0
|
||||
while i < len(text):
|
||||
remaining = win_width - x
|
||||
part_len = min(len(text) - i, remaining)
|
||||
i += part_len
|
||||
x += part_len
|
||||
if i < len(text):
|
||||
lines += 1
|
||||
x = subsequent_indent
|
||||
return lines
|
||||
|
||||
def render_event(win: Any, start_y: int, pairs: List[Tuple[str, int]], is_bold: bool, win_width: int, subsequent_indent: int) -> int:
|
||||
y = start_y
|
||||
x = 0
|
||||
for text, color in pairs:
|
||||
attr = curses.color_pair(color) | (curses.A_BOLD if is_bold else 0)
|
||||
i = 0
|
||||
while i < len(text):
|
||||
remaining = win_width - x
|
||||
part_len = min(len(text) - i, remaining)
|
||||
part = text[i:i + part_len]
|
||||
try:
|
||||
win.addstr(y, x, part, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
i += part_len
|
||||
x += part_len
|
||||
if i < len(text):
|
||||
y += 1
|
||||
if y >= win.getmaxyx()[0]:
|
||||
return y
|
||||
x = subsequent_indent
|
||||
if x > 0:
|
||||
y += 1
|
||||
return y
|
||||
|
||||
def draw_events(win: Any, events_list: List[EventFromEventLog[Event]], current_events: int, height: int) -> None:
|
||||
win.clear()
|
||||
if len(events_list) == 0:
|
||||
win.addstr(0, 0, "No events")
|
||||
win.refresh()
|
||||
return
|
||||
win_width = win.getmaxyx()[1]
|
||||
current_event = events_list[current_events]
|
||||
current_pairs = get_event_pairs(current_event)
|
||||
subsequent_indent = len(f"[{current_event.idx_in_log}] ")
|
||||
lines_current = calculate_event_lines(current_pairs, win_width, subsequent_indent)
|
||||
if lines_current > height:
|
||||
render_event(win, 0, current_pairs, True, win_width, subsequent_indent)
|
||||
win.refresh()
|
||||
return
|
||||
|
||||
target_above = (height - lines_current) // 2
|
||||
target_below = height - lines_current - target_above
|
||||
|
||||
# Collect previous events
|
||||
prev_events: List[int] = []
|
||||
remaining = target_above
|
||||
i = current_events - 1
|
||||
while i >= 0 and remaining > 0:
|
||||
event = events_list[i]
|
||||
pairs = get_event_pairs(event)
|
||||
indent = len(f"[{event.idx_in_log}] ")
|
||||
lines = calculate_event_lines(pairs, win_width, indent)
|
||||
if lines <= remaining:
|
||||
remaining -= lines
|
||||
prev_events.append(i)
|
||||
i -= 1
|
||||
else:
|
||||
break
|
||||
prev_events.reverse()
|
||||
|
||||
# Collect next events
|
||||
next_events: List[int] = []
|
||||
remaining = target_below
|
||||
j = current_events + 1
|
||||
while j < len(events_list) and remaining > 0:
|
||||
event = events_list[j]
|
||||
pairs = get_event_pairs(event)
|
||||
indent = len(f"[{event.idx_in_log}] ")
|
||||
lines = calculate_event_lines(pairs, win_width, indent)
|
||||
if lines <= remaining:
|
||||
remaining -= lines
|
||||
next_events.append(j)
|
||||
j += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# Calculate total lines
|
||||
total_lines = lines_current
|
||||
for idx in prev_events:
|
||||
event = events_list[idx]
|
||||
pairs = get_event_pairs(event)
|
||||
indent = len(f"[{event.idx_in_log}] ")
|
||||
total_lines += calculate_event_lines(pairs, win_width, indent)
|
||||
for idx in next_events:
|
||||
event = events_list[idx]
|
||||
pairs = get_event_pairs(event)
|
||||
indent = len(f"[{event.idx_in_log}] ")
|
||||
total_lines += calculate_event_lines(pairs, win_width, indent)
|
||||
|
||||
padding = (height - total_lines) // 2 if total_lines < height else 0
|
||||
|
||||
y = padding
|
||||
# Draw prev
|
||||
for idx in prev_events:
|
||||
event = events_list[idx]
|
||||
pairs = get_event_pairs(event)
|
||||
indent = len(f"[{event.idx_in_log}] ")
|
||||
y = render_event(win, y, pairs, False, win_width, indent)
|
||||
|
||||
# Draw current
|
||||
y = render_event(win, y, current_pairs, True, win_width, subsequent_indent)
|
||||
|
||||
# Draw next
|
||||
for idx in next_events:
|
||||
event = events_list[idx]
|
||||
pairs = get_event_pairs(event)
|
||||
indent = len(f"[{event.idx_in_log}] ")
|
||||
y = render_event(win, y, pairs, False, win_width, indent)
|
||||
|
||||
win.refresh()
|
||||
|
||||
def draw_status(win: Any, realtime: bool, current: int, total_events: int) -> None:
|
||||
win.clear()
|
||||
mode = "Realtime" if realtime else "Timetravel"
|
||||
win.addstr(0, 0, f"Mode: {mode} | Current event: {current} / {total_events} | Arrows: navigate events, [/]: scroll state, g: goto, r: toggle realtime, q: quit")
|
||||
win.refresh()
|
||||
|
||||
def get_input(stdscr: Any, prompt: str) -> str:
|
||||
curses.echo()
|
||||
stdscr.addstr(0, 0, prompt)
|
||||
stdscr.refresh()
|
||||
input_str = stdscr.getstr(0, len(prompt), 20).decode('utf-8')
|
||||
curses.noecho()
|
||||
return input_str
|
||||
|
||||
def get_key(win: Any) -> Any:
|
||||
ch = win.getch()
|
||||
if ch == -1:
|
||||
return -1
|
||||
if ch == 27:
|
||||
ch2 = win.getch()
|
||||
if ch2 == -1:
|
||||
return 27
|
||||
if ch2 == 91:
|
||||
ch3 = win.getch()
|
||||
if ch3 == -1:
|
||||
return -1
|
||||
if ch3 == 65:
|
||||
return curses.KEY_UP
|
||||
if ch3 == 66:
|
||||
return curses.KEY_DOWN
|
||||
if ch3 == 53:
|
||||
ch4 = win.getch()
|
||||
if ch4 == 126:
|
||||
return curses.KEY_PPAGE
|
||||
if ch3 == 54:
|
||||
ch4 = win.getch()
|
||||
if ch4 == 126:
|
||||
return curses.KEY_NPAGE
|
||||
if ch3 == 49:
|
||||
ch4 = win.getch()
|
||||
if ch4 == -1:
|
||||
return -1
|
||||
if ch4 == 59:
|
||||
ch5 = win.getch()
|
||||
if ch5 == -1:
|
||||
return -1
|
||||
if ch5 == 53:
|
||||
ch6 = win.getch()
|
||||
if ch6 == -1:
|
||||
return -1
|
||||
if ch6 == 65:
|
||||
return 'CTRL_UP'
|
||||
if ch6 == 66:
|
||||
return 'CTRL_DOWN'
|
||||
return ch
|
||||
|
||||
def tui(stdscr: Any) -> None:
|
||||
curses.start_color()
|
||||
curses.init_pair(1, curses.COLOR_BLUE, curses.COLOR_BLACK)
|
||||
curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK)
|
||||
curses.init_pair(3, curses.COLOR_MAGENTA, curses.COLOR_BLACK)
|
||||
curses.init_pair(4, curses.COLOR_YELLOW, curses.COLOR_BLACK)
|
||||
curses.init_pair(5, curses.COLOR_CYAN, curses.COLOR_BLACK)
|
||||
curses.init_pair(6, curses.COLOR_WHITE, curses.COLOR_BLACK)
|
||||
curses.use_default_colors()
|
||||
stdscr.timeout(100)
|
||||
curses.curs_set(0)
|
||||
|
||||
wrapped_events: List[EventFromEventLog[Event]] = []
|
||||
states: List[State] = [State()]
|
||||
asyncio.run(init_db())
|
||||
asyncio.run(update_events(wrapped_events, states)) # Initial load
|
||||
|
||||
filtered_indices: Optional[List[int]] = None
|
||||
current_filtered: int = -1
|
||||
current: int = -1
|
||||
if worker_mode:
|
||||
filtered_indices = [i for i in range(len(wrapped_events)) if type(wrapped_events[i].event).__name__ in WORKER_EVENT_TYPES]
|
||||
current_filtered = len(filtered_indices) - 1 if filtered_indices else -1
|
||||
else:
|
||||
current = len(wrapped_events) - 1 if wrapped_events else -1
|
||||
|
||||
realtime: bool = False
|
||||
last_update: float = time.time()
|
||||
update_interval: float = 1.0
|
||||
state_scroll: int = 0
|
||||
|
||||
while True:
|
||||
height, width = stdscr.getmaxyx()
|
||||
status_height = 1
|
||||
pane_height = height - status_height
|
||||
pane_width = width // 2
|
||||
|
||||
state_win = curses.newwin(pane_height, pane_width, 0, 0)
|
||||
events_win = curses.newwin(pane_height, width - pane_width, 0, pane_width)
|
||||
status_win = curses.newwin(status_height, width, pane_height, 0)
|
||||
|
||||
if worker_mode:
|
||||
assert filtered_indices is not None
|
||||
current_original = filtered_indices[current_filtered] if current_filtered >= 0 else -1
|
||||
events_list = [wrapped_events[i] for i in filtered_indices]
|
||||
current_events = current_filtered
|
||||
else:
|
||||
current_original = current
|
||||
events_list = wrapped_events
|
||||
current_events = current
|
||||
|
||||
state_idx = current_original + 1 if current_original >= 0 else 0
|
||||
state_scroll = draw_state(state_win, states[state_idx], pane_height, pane_width, worker_mode, state_scroll)
|
||||
draw_events(events_win, events_list, current_events, pane_height)
|
||||
total_events = len(wrapped_events) - 1 if wrapped_events else -1
|
||||
draw_status(status_win, realtime, current_original if worker_mode else current, total_events)
|
||||
|
||||
key = get_key(stdscr)
|
||||
if key != -1:
|
||||
if key == curses.KEY_UP:
|
||||
if worker_mode and current_filtered > 0:
|
||||
current_filtered -= 1
|
||||
elif not worker_mode and current > 0:
|
||||
current -= 1
|
||||
elif key == 'CTRL_UP':
|
||||
if worker_mode:
|
||||
current_filtered = max(0, current_filtered - 5)
|
||||
else:
|
||||
current = max(0, current - 5)
|
||||
elif key == curses.KEY_DOWN:
|
||||
if worker_mode and current_filtered < len(filtered_indices) - 1: # type: ignore[arg-type]
|
||||
current_filtered += 1
|
||||
elif not worker_mode and current < len(wrapped_events) - 1:
|
||||
current += 1
|
||||
elif key == 'CTRL_DOWN':
|
||||
if worker_mode:
|
||||
current_filtered = min(len(filtered_indices) - 1, current_filtered + 5) # type: ignore[arg-type]
|
||||
else:
|
||||
current = min(len(wrapped_events) - 1, current + 5)
|
||||
elif key == ord('['):
|
||||
state_scroll = max(0, state_scroll - pane_height // 2)
|
||||
elif key == ord(']'):
|
||||
state_scroll += pane_height // 2 # clamped in draw_state
|
||||
elif key == ord('q'):
|
||||
break
|
||||
elif key == ord('r'):
|
||||
realtime = not realtime
|
||||
if realtime:
|
||||
if worker_mode:
|
||||
current_filtered = len(filtered_indices) - 1 if filtered_indices else -1 # type: ignore[arg-type]
|
||||
else:
|
||||
current = len(wrapped_events) - 1 if wrapped_events else -1
|
||||
state_scroll = 0
|
||||
elif key == ord('g'):
|
||||
stdscr.timeout(-1) # block for input
|
||||
input_str = get_input(status_win, "Go to event: ")
|
||||
try:
|
||||
goto = int(input_str)
|
||||
if worker_mode:
|
||||
assert filtered_indices is not None
|
||||
for i, orig in enumerate(filtered_indices):
|
||||
if wrapped_events[orig].idx_in_log == goto:
|
||||
current_filtered = i
|
||||
state_scroll = 0
|
||||
break
|
||||
else:
|
||||
for i in range(len(wrapped_events)):
|
||||
if wrapped_events[i].idx_in_log == goto:
|
||||
current = i
|
||||
state_scroll = 0
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
stdscr.timeout(100)
|
||||
status_win.clear()
|
||||
status_win.refresh()
|
||||
|
||||
if realtime and time.time() - last_update > update_interval:
|
||||
updated = asyncio.run(update_events(wrapped_events, states, filtered_indices if worker_mode else None))
|
||||
if updated:
|
||||
if worker_mode:
|
||||
current_filtered = len(filtered_indices) - 1 # type: ignore[arg-type]
|
||||
else:
|
||||
current = len(wrapped_events) - 1
|
||||
state_scroll = 0
|
||||
last_update = time.time()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Read and display events from the event log')
|
||||
parser.add_argument('--worker', action='store_true', help='Only show worker-related events (task, streaming, instance, runner status)')
|
||||
args = parser.parse_args()
|
||||
|
||||
worker_mode = args.worker
|
||||
|
||||
if not sys.stdout.isatty():
|
||||
asyncio.run(non_tui_mode())
|
||||
else:
|
||||
try:
|
||||
curses.wrapper(tui)
|
||||
except curses.error as e:
|
||||
if "could not find terminal" in str(e):
|
||||
print("Error: Could not find terminal. Falling back to non-TUI mode.")
|
||||
asyncio.run(non_tui_mode())
|
||||
else:
|
||||
raise
|
||||
12  scripts/test_download.py  Normal file
@@ -0,0 +1,12 @@
from worker.download.download_utils import *

async def main():
    meta = await file_meta(
        'mlx-community/DeepSeek-R1-4bit',
        revision='main',
        path='config.json',
        redirected_location=None,
    )
    print(meta)

asyncio.run(main())
@@ -140,10 +140,10 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_performance_measured(event: NodePerformanceMeasured, state: State) -> State:
    new_profiles: Mapping[NodeId, NodePerformanceProfile] = {**state.node_profiles, event.node_id: event.node_profile}
    state = state.model_copy(update={"node_profiles": new_profiles})
    if not state.topology.contains_node(event.node_id):
        # TODO: figure out why this is happening in the first place
        return state
    topology = copy.copy(state.topology)
    if not topology.contains_node(event.node_id):
        # TODO: figure out why this is happening in the first place
        topology.add_node(Node(node_id=event.node_id))
    topology.update_node_profile(event.node_id, event.node_profile)
    return state.model_copy(update={"topology": topology})

@@ -164,13 +164,6 @@ def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
    topology = copy.copy(state.topology)
    topology.add_connection(event.edge)
    opposite_edge = Connection(
        local_node_id=event.edge.send_back_node_id,
        send_back_node_id=event.edge.local_node_id,
        local_multiaddr=event.edge.send_back_multiaddr,
        send_back_multiaddr=event.edge.local_multiaddr
    )
    topology.add_connection(opposite_edge)
    return state.model_copy(update={"topology": topology})

@event_apply.register(TopologyEdgeReplacedAtomically)

@@ -20,6 +20,10 @@ EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
LIBP2P_WORKER_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"

# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip.
LB_TFLOPS = 2.3
LB_MEMBW_GBPS = 68
LB_DISK_GBPS = 1.5

# little helper function to get the name of the module that raised the error
def get_caller_module_name() -> str:

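Reviewer note: one plausible use of these lower bounds is deriving conservative worker timeouts; the formula below is an assumption for illustration, not code from this commit:

def conservative_load_timeout_s(model_size_gb: float, lb_disk_gbps: float = 1.5) -> float:
    # Assume the slowest supported disk bandwidth (LB_DISK_GBPS) plus 50% headroom.
    return (model_size_gb / lb_disk_gbps) * 1.5

print(conservative_load_timeout_s(40.0))  # 40.0 seconds for a 40 GB model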
@@ -14,7 +14,20 @@ class ModelCard(BaseModel):
    metadata: ModelMetadata


MODEL_CARDS = {
MODEL_CARDS: dict[str, ModelCard] = {
    "deepseek-v3-0324": ModelCard(
        short_id="deepseek-v3-0324",
        model_id="mlx-community/DeepSeek-v3-0324-8bit",
        name="DeepSeek V3 fp8",
        description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id="mlx-community/DeepSeek-v3-0324-8bit",
            pretty_name="DeepSeek V3 fp8",
            storage_size_kilobytes=754998771712//1024,
            n_layers=61,
        ),
    ),
    "llama-3.3": ModelCard(
        short_id="llama-3.3",
        model_id="mlx-community/Llama-3.3-70B-Instruct-4bit",

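Reviewer note: with the registry now typed as dict[str, ModelCard], a new entry follows the same shape; the values below are hypothetical, for illustration only:

# Hypothetical example entry; model id and numbers are illustrative only.
MODEL_CARDS["example-8b"] = ModelCard(
    short_id="example-8b",
    model_id="mlx-community/Example-8B-4bit",
    name="Example 8B",
    description="Illustrative entry showing the ModelCard shape.",
    tags=[],
    metadata=ModelMetadata(
        model_id="mlx-community/Example-8B-4bit",
        pretty_name="Example 8B",
        storage_size_kilobytes=4_500_000,
        n_layers=32,
    ),
)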
@@ -1,6 +1,7 @@
from typing import Annotated, Dict, Optional

import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from pydantic import BaseModel, Field

@@ -8,7 +9,7 @@ from shared.types.models import ModelMetadata
from worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,
    ensure_exo_tmp,
    ensure_models_dir,
)


@@ -43,14 +44,16 @@ class ConfigData(BaseModel):

async def get_config_data(model_id: str) -> ConfigData:
    """Downloads and parses config.json for a model."""
    target_dir = (await ensure_exo_tmp())/model_id.replace("/", "--")
    target_dir = (await ensure_models_dir())/str(model_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    config_path = await download_file_with_retry(model_id, "main", "config.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes}"))
    async with aiofiles.open(config_path, 'r') as f:
        return ConfigData.model_validate_json(await f.read())

async def get_safetensors_size(model_id: str) -> int:
    """Gets model size from safetensors index or falls back to HF API."""
    target_dir = (await ensure_exo_tmp())/model_id.replace("/", "--")
    target_dir = (await ensure_models_dir())/str(model_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    index_path = await download_file_with_retry(model_id, "main", "model.safetensors.index.json", target_dir, lambda curr_bytes, total_bytes: print(f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes}"))
    async with aiofiles.open(index_path, 'r') as f:
        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

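Reviewer note: minimal usage sketch for the two helpers above, assuming they are awaited from an async context and that the returned size is in bytes:

import asyncio

async def describe_model(model_id: str) -> None:
    config = await get_config_data(model_id)        # downloads config.json
    size = await get_safetensors_size(model_id)     # safetensors index or HF API fallback
    print(config)
    print(f"{model_id}: ~{size / 1e9:.1f} GB of weights")

# asyncio.run(describe_model("mlx-community/Llama-3.3-70B-Instruct-4bit"))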
@@ -161,6 +161,22 @@ class Topology(TopologyProto):
            topology.add_connection(connection)
        return topology

    def is_thunderbolt_cycle(self, cycle: list[Node]) -> bool:
        node_idxs = [node.node_id for node in cycle]
        rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
        for rid in rx_idxs:
            for neighbor_rid in self._graph.neighbors(rid):
                if neighbor_rid not in rx_idxs:
                    continue
                has_tb = False
                for edge in self._graph.get_all_edge_data(rid, neighbor_rid):
                    if edge.is_thunderbolt():
                        has_tb = True
                        break
                if not has_tb:
                    return False
        return True

    def _is_bridge(self, connection: Connection) -> bool:
        """Check if removing this connection will orphan any nodes from the master."""
        if self.master_node_id is None:

@@ -1,4 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from typing import Any, Self
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -29,7 +29,7 @@ class CommandId(ID):
|
||||
|
||||
|
||||
class Host(BaseModel):
|
||||
ip: IPv4Address
|
||||
ip: IPv4Address | IPv6Address
|
||||
port: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -16,6 +16,7 @@ class CommandType(str, Enum):
|
||||
CHAT_COMPLETION = "CHAT_COMPLETION"
|
||||
CREATE_INSTANCE = "CREATE_INSTANCE"
|
||||
DELETE_INSTANCE = "DELETE_INSTANCE"
|
||||
TASK_FINISHED = "TASK_FINISHED"
|
||||
|
||||
|
||||
class _BaseCommand[T: CommandType](BaseModel):
|
||||
@@ -39,8 +40,12 @@ class DeleteInstanceCommand(_BaseCommand[CommandType.DELETE_INSTANCE]):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class TaskFinishedCommand(_BaseCommand[CommandType.TASK_FINISHED]):
|
||||
command_type: Literal[CommandType.TASK_FINISHED] = CommandType.TASK_FINISHED
|
||||
|
||||
|
||||
Command = Annotated[
|
||||
ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand,
|
||||
ChatCompletionCommand | CreateInstanceCommand | DeleteInstanceCommand | TaskFinishedCommand,
|
||||
Field(discriminator="command_type")
|
||||
]
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel, computed_field, field_serializer, field_validator
|
||||
@@ -25,6 +25,20 @@ class Multiaddr(BaseModel):
|
||||
return v
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def address_type(self) -> str:
|
||||
for pattern in self.PATTERNS:
|
||||
if re.match(pattern, self.address):
|
||||
return pattern.split('/')[1]
|
||||
raise ValueError(f"Invalid multiaddr format: {self.address}")
|
||||
|
||||
@property
|
||||
def ipv6_address(self) -> IPv6Address:
|
||||
match = re.match(r'^/ip6/([0-9a-fA-F:]+)', self.address)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid multiaddr format: {self.address}. Expected format like /ip6/::1/tcp/4001")
|
||||
return IPv6Address(match.group(1))
|
||||
|
||||
@property
|
||||
def ipv4_address(self) -> IPv4Address:
|
||||
match = re.match(r'^/ip4/(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})', self.address)
|
||||
@@ -32,11 +46,15 @@ class Multiaddr(BaseModel):
|
||||
raise ValueError(f"Invalid multiaddr format: {self.address}. Expected format like /ip4/127.0.0.1/tcp/4001")
|
||||
return IPv4Address(match.group(1))
|
||||
|
||||
@field_serializer("ipv4_address")
|
||||
def serialize_ipv4_address(self, value: IPv4Address) -> str:
|
||||
@computed_field
|
||||
@property
|
||||
def ip_address(self) -> IPv4Address | IPv6Address:
|
||||
return self.ipv4_address if self.address_type == 'ip4' else self.ipv6_address
|
||||
|
||||
@field_serializer("ip_address")
|
||||
def serialize_ipv4_address(self, value: IPv4Address | IPv6Address) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def port(self) -> int:
|
||||
|
||||
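A minimal sketch of the new address handling, assuming Multiaddr takes the address string matched by the regexes above (the import path is hypothetical):

from ipaddress import IPv4Address, IPv6Address

# Hypothetical import path; adjust to wherever Multiaddr is defined.
from shared.types.multiaddr import Multiaddr

v4 = Multiaddr(address="/ip4/192.168.1.10/tcp/4001")
v6 = Multiaddr(address="/ip6/::1/tcp/4001")

# address_type picks the matching pattern; ip_address dispatches to the right parser.
assert v4.address_type == "ip4" and isinstance(v4.ip_address, IPv4Address)
assert v6.address_type == "ip6" and isinstance(v6.ip_address, IPv6Address)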
@@ -22,8 +22,8 @@ class Connection(BaseModel):
|
||||
(
|
||||
self.local_node_id,
|
||||
self.send_back_node_id,
|
||||
self.local_multiaddr.ipv4_address,
|
||||
self.send_back_multiaddr.ipv4_address,
|
||||
self.local_multiaddr.ip_address,
|
||||
self.send_back_multiaddr.ip_address,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -33,9 +33,12 @@ class Connection(BaseModel):
|
||||
return (
|
||||
self.local_node_id == other.local_node_id
|
||||
and self.send_back_node_id == other.send_back_node_id
|
||||
and self.local_multiaddr.ipv4_address == other.local_multiaddr.ipv4_address
|
||||
and self.send_back_multiaddr.ipv4_address == other.send_back_multiaddr.ipv4_address
|
||||
and self.local_multiaddr.ip_address == other.local_multiaddr.ip_address
|
||||
and self.send_back_multiaddr.ip_address == other.send_back_multiaddr.ip_address
|
||||
)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.local_multiaddr.ip_address).startswith('169.254') and str(self.send_back_multiaddr.ip_address).startswith('169.254')
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
|
||||
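The is_thunderbolt check above treats a connection whose endpoints both sit in 169.254.0.0/16 (the link-local range used by macOS Thunderbolt bridge interfaces) as a Thunderbolt link. A self-contained sketch of the same heuristic, independent of the Connection model:

from ipaddress import ip_address, ip_network

LINK_LOCAL = ip_network("169.254.0.0/16")  # link-local range used by Thunderbolt bridges on macOS

def looks_like_thunderbolt(local_ip: str, remote_ip: str) -> bool:
    # Both ends must be link-local for the connection to count as Thunderbolt.
    return ip_address(local_ip) in LINK_LOCAL and ip_address(remote_ip) in LINK_LOCAL

assert looks_like_thunderbolt("169.254.10.2", "169.254.10.3")
assert not looks_like_thunderbolt("192.168.1.5", "169.254.10.3")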
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from shared.types.common import ID
|
||||
|
||||
@@ -14,3 +15,12 @@ class RunnerId(ID):
|
||||
class NodeStatus(str, Enum):
|
||||
Idle = "Idle"
|
||||
Running = "Running"
|
||||
|
||||
class RunnerError(Exception):
|
||||
"""Exception raised when the runner process encounters an error."""
|
||||
|
||||
def __init__(self, error_type: str, error_message: str, traceback: Optional[str] = None):
|
||||
self.error_type = error_type
|
||||
self.error_message = error_message
|
||||
self.traceback = traceback
|
||||
super().__init__(f"{error_type}: {error_message}")
|
||||
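A short sketch of raising and handling the new RunnerError; the import path is the one used later in this diff:

from shared.types.worker.common import RunnerError

try:
    raise RunnerError("RunnerCrash", "runner exited with code -9")
except RunnerError as e:
    # The structured fields survive, so the supervisor can log them or turn them into TaskFailed events.
    print(e.error_type, e.error_message, e.traceback)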
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Generic, Literal, TypeVar
|
||||
from typing import Annotated, Generic, Literal, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
@@ -24,6 +24,11 @@ class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
|
||||
partition_strategy: PartitionStrategyT
|
||||
device_rank: int
|
||||
world_size: int
|
||||
|
||||
# Error handling; equivalent to monkey-patch, but we can't monkey-patch runner.py
|
||||
# This is kinda annoying because it allocates memory in the ShardMetadata object. Can be rethought after Shanghai.
|
||||
immediate_exception: bool = False
|
||||
should_timeout: Optional[float] = None
|
||||
|
||||
|
||||
class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):
|
||||
|
||||
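These two fields are fault-injection hooks for the new handler tests later in this diff. A toy model showing just those fields and how a test flips them (the real ShardMetadata carries many more fields):

from typing import Optional

from pydantic import BaseModel

class FakeShardMeta(BaseModel):
    # Mirrors only the two fault-injection fields added above.
    immediate_exception: bool = False
    should_timeout: Optional[float] = None

meta = FakeShardMeta()
meta.immediate_exception = True  # runner raises as soon as it starts up
meta.should_timeout = 10.0       # runner sleeps 10s during setup, tripping the init timeout
print(meta.model_dump())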
31
uv.lock
generated
31
uv.lock
generated
@@ -15,6 +15,7 @@ members = [
|
||||
"exo",
|
||||
"exo-engine-mlx",
|
||||
"exo-master",
|
||||
"exo-scripts",
|
||||
"exo-shared",
|
||||
"exo-worker",
|
||||
]
|
||||
@@ -303,6 +304,21 @@ requires-dist = [
|
||||
{ name = "uvicorn", specifier = ">=0.35.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exo-scripts"
|
||||
version = "0.1.0"
|
||||
source = { editable = "scripts" }
|
||||
dependencies = [
|
||||
{ name = "exo-shared", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "exo-shared", editable = "shared" },
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exo-shared"
|
||||
version = "0.1.0"
|
||||
@@ -365,6 +381,7 @@ dependencies = [
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -373,6 +390,7 @@ requires-dist = [
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
{ name = "mlx", specifier = "==0.26.3" },
|
||||
{ name = "mlx-lm", specifier = ">=0.25.3" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -840,6 +858,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/af/ab3c51ab7507a7325e98ffe691d9495ee3d3aa5f589afad65ec920d39821/protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e", size = 168724, upload-time = "2025-05-28T19:25:53.926Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psutil"
|
||||
version = "7.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003, upload-time = "2025-02-13T21:54:07.946Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051, upload-time = "2025-02-13T21:54:12.36Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535, upload-time = "2025-02-13T21:54:16.07Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004, upload-time = "2025-02-13T21:54:18.662Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986, upload-time = "2025-02-13T21:54:21.811Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544, upload-time = "2025-02-13T21:54:24.68Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
@@ -91,9 +90,6 @@ class RepoDownloadProgress(BaseModel):
|
||||
def build_model_path(model_id: str) -> DirectoryPath:
|
||||
return EXO_HOME / "models" / model_id.replace("/", "--")
|
||||
|
||||
def exo_tmp() -> Path:
|
||||
return Path(tempfile.gettempdir())/"exo"
|
||||
|
||||
async def resolve_model_path_for_repo(repo_id: str) -> Path:
|
||||
return (await ensure_models_dir())/repo_id.replace("/", "--")
|
||||
|
||||
@@ -101,10 +97,6 @@ async def ensure_exo_home() -> Path:
|
||||
await aios.makedirs(EXO_HOME, exist_ok=True)
|
||||
return EXO_HOME
|
||||
|
||||
async def ensure_exo_tmp() -> Path:
|
||||
await aios.makedirs(exo_tmp(), exist_ok=True)
|
||||
return exo_tmp()
|
||||
|
||||
async def has_exo_home_read_access() -> bool:
|
||||
try:
|
||||
return await aios.access(EXO_HOME, os.R_OK)
|
||||
@@ -146,7 +138,9 @@ async def seed_models(seed_dir: Union[str, Path]):
|
||||
traceback.print_exc()
|
||||
|
||||
async def fetch_file_list_with_cache(repo_id: str, revision: str = "main", recursive: bool = False) -> List[FileListEntry]:
|
||||
cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
|
||||
target_dir = (await ensure_models_dir())/"caches"/str(repo_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
cache_file = target_dir/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, 'r') as f:
|
||||
return TypeAdapter(List[FileListEntry]).validate_json(await f.read())
|
||||
@@ -198,22 +192,29 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
|
||||
hasher.update(chunk)
|
||||
return hasher.hexdigest()
|
||||
|
||||
async def file_meta(repo_id: str, revision: str, path: str, redirected_location: str | None = None) -> Tuple[int, str]:
|
||||
# NOTE: huggingface broke the E-Tag so we can no longer assume E-Tag == sha256(file)
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path) if redirected_location is None else f"{get_hf_endpoint()}{redirected_location}"
|
||||
headers = await get_auth_headers()
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session, session.head(url, headers=headers) as r:
|
||||
if r.status == 307:
|
||||
redirected_location = r.headers.get('Location')
|
||||
return await file_meta(repo_id, revision, path, redirected_location)
|
||||
|
||||
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
|
||||
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
|
||||
assert content_length > 0, f"No content length for {url}"
|
||||
assert etag is not None, f"No remote hash for {url}"
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
|
||||
etag = etag[1:-1]
|
||||
return content_length, etag
|
||||
async def file_meta(repo_id: str, revision: str, path: str, redirected_location: str | None = None) -> Tuple[int, str]:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path) if redirected_location is None else f"{get_hf_endpoint()}{redirected_location}"
|
||||
headers = await get_auth_headers()
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session, session.head(url, headers=headers) as r:
|
||||
if r.status == 307:
|
||||
# Try to extract from X-Linked headers first (common for HF redirects)
|
||||
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
|
||||
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
|
||||
if content_length > 0 and etag is not None:
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
|
||||
etag = etag[1:-1]
|
||||
return content_length, etag
|
||||
# If not available, recurse with the redirect
|
||||
redirected_location = r.headers.get('Location')
|
||||
return await file_meta(repo_id, revision, path, redirected_location)
|
||||
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
|
||||
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
|
||||
assert content_length > 0, f"No content length for {url}"
|
||||
assert etag is not None, f"No remote hash for {url}"
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
|
||||
etag = etag[1:-1]
|
||||
return content_length, etag
|
||||
|
||||
async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
|
||||
n_attempts = 30
|
||||
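A usage sketch for file_meta as defined above; the repo and file names are examples only:

import asyncio

from worker.download.download_utils import file_meta

async def main() -> None:
    # Returns (content_length_bytes, etag). 307 redirects are followed, preferring the
    # X-Linked-Size / X-Linked-ETag headers when Hugging Face already includes them.
    size, etag = await file_meta("mlx-community/Llama-3.2-1B-Instruct-4bit", "main", "config.json")
    print(size, etag)

asyncio.run(main())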
@@ -291,7 +292,8 @@ def calculate_repo_progress(shard: ShardMetadata, repo_id: str, revision: str, f
|
||||
)
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
|
||||
target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
|
||||
target_dir = (await ensure_models_dir())/str(repo_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
|
||||
async with aiofiles.open(index_file, 'r') as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
|
||||
@@ -9,16 +9,17 @@ from shared.types.events import (
|
||||
)
|
||||
from shared.types.profiling import NodePerformanceProfile
|
||||
from shared.types.worker.ops import (
|
||||
ExecuteTaskOp,
|
||||
RunnerOp,
|
||||
)
|
||||
from shared.utils import get_node_id_keypair
|
||||
from shared.utils import Keypair, get_node_id_keypair
|
||||
from worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from worker.plan import plan
|
||||
from worker.utils.profile import start_polling_node_metrics
|
||||
from worker.worker import Worker
|
||||
|
||||
|
||||
async def run(worker_state: Worker):
|
||||
async def run(worker_state: Worker, logger: logging.Logger):
|
||||
assert worker_state.global_events is not None
|
||||
|
||||
while True:
|
||||
@@ -42,15 +43,26 @@ async def run(worker_state: Worker):
|
||||
|
||||
# run the op, synchronously blocking for now
|
||||
if op is not None:
|
||||
async for event in worker_state.execute_op(op):
|
||||
await worker_state.event_publisher(event)
|
||||
logger.info(f'Executing op {op}')
|
||||
try:
|
||||
async for event in worker_state.execute_op(op):
|
||||
await worker_state.event_publisher(event)
|
||||
except Exception as e:
|
||||
if isinstance(op, ExecuteTaskOp):
|
||||
generator = worker_state.fail_task(e, runner_id=op.runner_id, task_id=op.task.task_id)
|
||||
else:
|
||||
generator = worker_state.fail_runner(e, runner_id=op.runner_id)
|
||||
|
||||
async for event in generator:
|
||||
await worker_state.event_publisher(event)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
node_id_keypair = get_node_id_keypair()
|
||||
node_id_keypair: Keypair = get_node_id_keypair()
|
||||
node_id = NodeId(node_id_keypair.to_peer_id().to_base58())
|
||||
logger: logging.Logger = logging.getLogger('worker_logger')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@@ -72,7 +84,7 @@ async def main():
|
||||
|
||||
worker = Worker(node_id, logger, shard_downloader, event_log_manager.worker_events, event_log_manager.global_events)
|
||||
|
||||
await run(worker)
|
||||
await run(worker, logger)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -95,7 +95,8 @@ def spin_down_runners(
|
||||
|
||||
num_spundown_nodes = 0
|
||||
for runner_id in instance.shard_assignments.runner_to_shard:
|
||||
if isinstance(state_runners[runner_id], InactiveRunnerStatus) and \
|
||||
if runner_id in state_runners and \
|
||||
isinstance(state_runners[runner_id], InactiveRunnerStatus) and \
|
||||
runner_id not in assigned_runners:
|
||||
num_spundown_nodes += 1
|
||||
# Suggested:
|
||||
|
||||
@@ -9,7 +9,7 @@ dependencies = [
|
||||
"huggingface_hub>=0.33.4",
|
||||
"mlx==0.26.3",
|
||||
"mlx-lm>=0.25.3",
|
||||
|
||||
"psutil>=7.0.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -34,7 +34,7 @@ async def runner_read_message() -> RunnerMessage:
|
||||
|
||||
line: bytes = await loop.run_in_executor(None, sys.stdin.buffer.readline)
|
||||
if not line: # This seems to be what triggers when we don't clean up the runner neatly and leave the process dangling.
|
||||
raise EOFError("No more data to read")
|
||||
raise EOFError("No more data to read when reading runner message")
|
||||
line = line.strip()
|
||||
|
||||
try:
|
||||
@@ -66,7 +66,7 @@ async def supervisor_read_response(
|
||||
line: str = line_bytes.decode("utf-8").strip()
|
||||
|
||||
if not line:
|
||||
raise EOFError("No more data to read")
|
||||
raise EOFError("No more data to read when reading response from runner")
|
||||
|
||||
try:
|
||||
return RunnerResponseTypeAdapter.validate_json(line)
|
||||
|
||||
@@ -10,7 +10,7 @@ import mlx.nn as nn
|
||||
from mlx_lm.generate import stream_generate # type: ignore
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx
|
||||
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx, mlx_force_oom
|
||||
from shared.openai_compat import FinishReason
|
||||
from shared.types.tasks import ChatCompletionTaskParams
|
||||
from shared.types.worker.commands_runner import (
|
||||
@@ -73,7 +73,7 @@ async def _mlx_generate(
|
||||
chat_task_data=task,
|
||||
)
|
||||
|
||||
max_tokens = task.max_tokens or 100
|
||||
max_tokens = task.max_tokens or 1000
|
||||
generation_fn = partial(_generate_tokens, prompt, max_tokens)
|
||||
|
||||
future = loop.run_in_executor(mlx_executor, generation_fn)
|
||||
@@ -105,6 +105,12 @@ async def main():
|
||||
setup_message = ensure_type(init_message, SetupMessage)
|
||||
model_shard_meta = setup_message.model_shard_meta
|
||||
hosts = setup_message.hosts
|
||||
|
||||
# For testing - these are fake break conditions
|
||||
if model_shard_meta.immediate_exception:
|
||||
raise Exception('Fake exception - runner failed to spin up.')
|
||||
if model_shard_meta.should_timeout:
|
||||
await asyncio.sleep(model_shard_meta.should_timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
|
||||
@@ -127,7 +133,12 @@ async def main():
|
||||
# TODO: this is a hack, why are we only looking at the first message? should have a tokenizer
|
||||
prompt = task.messages[0]
|
||||
if prompt.content is not None and 'EXO RUNNER MUST FAIL' in prompt.content:
|
||||
runner_print('raising exception')
|
||||
raise Exception('Artificial runner exception - for testing purposes only.')
|
||||
if prompt.content is not None and 'EXO RUNNER MUST OOM' in prompt.content:
|
||||
mlx_force_oom()
|
||||
if prompt.content is not None and 'EXO RUNNER MUST TIMEOUT' in prompt.content:
|
||||
await asyncio.sleep(100)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
async for generation_response in _mlx_generate(
|
||||
|
||||
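These prompts are test-only escape hatches; the new handler tests trigger them by embedding the magic strings in a message. A sketch of building such a prompt (the import location of ChatCompletionMessage is an assumption):

from shared.types.tasks import ChatCompletionMessage, ChatCompletionTaskParams  # ChatCompletionMessage location assumed

params = ChatCompletionTaskParams(
    model="gpt-4",
    messages=[ChatCompletionMessage(role="user", content="Artificial prompt: EXO RUNNER MUST TIMEOUT")],
)
# When the runner sees this content it sleeps for 100s, so the streaming timeout fires
# and the supervisor surfaces a RunnerError instead of hanging.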
@@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import sys
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
from logging import Logger
|
||||
from types import CoroutineType
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import psutil
|
||||
|
||||
from shared.types.common import CommandId, Host
|
||||
from shared.types.events.chunks import GenerationChunk, TokenChunk
|
||||
@@ -12,7 +14,6 @@ from shared.types.tasks import ChatCompletionTaskParams, Task
|
||||
from shared.types.worker.commands_runner import (
|
||||
ChatTaskMessage,
|
||||
ErrorResponse,
|
||||
ExitMessage,
|
||||
FinishedResponse,
|
||||
GenerationResponse,
|
||||
InitializedResponse,
|
||||
@@ -20,12 +21,19 @@ from shared.types.worker.commands_runner import (
|
||||
RunnerResponse,
|
||||
SetupMessage,
|
||||
)
|
||||
from shared.types.worker.common import RunnerError
|
||||
from shared.types.worker.shards import ShardMetadata
|
||||
from worker.runner.communication import (
|
||||
supervisor_read_response,
|
||||
supervisor_write_message,
|
||||
)
|
||||
from worker.runner.utils import get_runner_command
|
||||
from worker.runner.utils import (
|
||||
get_init_timeout,
|
||||
get_prefil_timeout,
|
||||
get_runner_command,
|
||||
get_token_generate_timeout,
|
||||
get_weights_size_kb,
|
||||
)
|
||||
|
||||
|
||||
class RunnerSupervisor:
|
||||
@@ -33,47 +41,52 @@ class RunnerSupervisor:
|
||||
RunnerSupervisor manages the lifecycle of a runner subprocess for model inference.
|
||||
Use the class method `create` to properly initialize an instance.
|
||||
"""
|
||||
|
||||
# TODO: Logger.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_shard_meta: ShardMetadata,
|
||||
hosts: list[Host],
|
||||
runner_process: asyncio.subprocess.Process,
|
||||
logger: Logger,
|
||||
):
|
||||
"""Private constructor. Use RunnerSupervisor.create() instead."""
|
||||
self.model_shard_meta: ShardMetadata = model_shard_meta
|
||||
self.hosts: list[Host] = hosts
|
||||
self.runner_process: asyncio.subprocess.Process = runner_process
|
||||
self.running: bool = True
|
||||
|
||||
self.stderr_task = asyncio.create_task(self._watch_stderr(logger))
|
||||
self.running_task: asyncio.Task[None] = asyncio.create_task(
|
||||
self._watch_runner()
|
||||
)
|
||||
self.logger = logger
|
||||
self.stderr_buffer: list[str] = [] # Accumulate stderr lines
|
||||
self.crash_detected: bool = False
|
||||
self.returncode: int | None = None
|
||||
self.stderr_output: str | None = None
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
model_shard_meta: ShardMetadata,
|
||||
hosts: list[Host],
|
||||
logger: Logger
|
||||
logger: Logger,
|
||||
initialize_timeout: Optional[float] = None,
|
||||
) -> "RunnerSupervisor":
|
||||
"""
|
||||
Create and initialize a RunnerSupervisor instance.
|
||||
The .create() classmethod pattern is used because __init__ cannot await the asynchronous subprocess setup this requires.
|
||||
"""
|
||||
cmd: list[str] = get_runner_command()
|
||||
|
||||
runner_process: asyncio.subprocess.Process = (
|
||||
await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=sys.stderr
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
)
|
||||
|
||||
print(f'{model_shard_meta=}')
|
||||
logger.info(f'initializing mlx instance with {model_shard_meta=}')
|
||||
await supervisor_write_message(
|
||||
runner_process,
|
||||
SetupMessage(
|
||||
@@ -82,88 +95,159 @@ class RunnerSupervisor:
|
||||
),
|
||||
)
|
||||
|
||||
while True:
|
||||
line: RunnerResponse | None = await supervisor_read_response(
|
||||
runner_process
|
||||
)
|
||||
if line is None or isinstance(line, PrintResponse):
|
||||
# print(line)
|
||||
continue
|
||||
elif isinstance(line, ErrorResponse):
|
||||
raise Exception(line.error_type, line.error_message, line.traceback or "")
|
||||
else:
|
||||
assert isinstance(line, InitializedResponse)
|
||||
logger.info(f'Runner initialized in {line.time_taken} seconds')
|
||||
print(f'Runner initialized in {line.time_taken} seconds')
|
||||
break
|
||||
async def read_initialization_message() -> None:
|
||||
while True:
|
||||
line: RunnerResponse | None = await supervisor_read_response(
|
||||
runner_process
|
||||
)
|
||||
if line is None:
|
||||
continue
|
||||
elif isinstance(line, PrintResponse):
|
||||
logger.info(line)
|
||||
continue
|
||||
elif isinstance(line, ErrorResponse):
|
||||
raise RunnerError(line.error_type, line.error_message, line.traceback or "")
|
||||
elif isinstance(line, InitializedResponse):
|
||||
assert isinstance(line, InitializedResponse)
|
||||
logger.info(f'Runner initialized in {line.time_taken} seconds')
|
||||
break
|
||||
else:
|
||||
raise AssertionError(f'Non-valid line read from runner during initialization: {line}')
|
||||
|
||||
if not initialize_timeout:
|
||||
initialize_timeout = get_init_timeout(model_shard_meta)
|
||||
await asyncio.wait_for(read_initialization_message(), timeout=initialize_timeout)
|
||||
return cls(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts,
|
||||
runner_process=runner_process,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
async def astop(self) -> None:
|
||||
async def terminate() -> None:
|
||||
# Check if process is already dead before trying to terminate
|
||||
if self.runner_process.returncode is None:
|
||||
self.runner_process.terminate()
|
||||
|
||||
# Wait for the process to exit (or confirm it's already exited)
|
||||
try:
|
||||
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
# If terminate didn't work, force kill
|
||||
if self.runner_process.returncode is None:
|
||||
self.runner_process.kill()
|
||||
_ = await self.runner_process.wait()
|
||||
# Cancel the stderr monitoring task
|
||||
if not self.stderr_task.done():
|
||||
self.stderr_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self.stderr_task
|
||||
|
||||
if not self.healthy:
|
||||
print("Runner process is not healthy, killing...")
|
||||
await terminate()
|
||||
print('terminated')
|
||||
|
||||
if self.runner_process.stdout is not None:
|
||||
# Kill the process and all its children
|
||||
await self._kill_process_tree()
|
||||
|
||||
# Wait to make sure that the model has been unloaded from memory
|
||||
async def wait_for_memory_release() -> None:
|
||||
required_memory_bytes = get_weights_size_kb(self.model_shard_meta) * 1024
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
try:
|
||||
line = await asyncio.wait_for(
|
||||
self.runner_process.stdout.readline(), timeout=0.01
|
||||
)
|
||||
if not line:
|
||||
break
|
||||
print(f"Remaining stdout: {line.decode('utf-8').strip()}")
|
||||
except asyncio.TimeoutError:
|
||||
available_memory_bytes = psutil.virtual_memory().available
|
||||
if available_memory_bytes >= required_memory_bytes:
|
||||
break
|
||||
if asyncio.get_event_loop().time() - start_time > 30.0:
|
||||
self.logger.warning("Timeout waiting for memory release after 30 seconds")
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Only try to send ExitMessage if process is still alive
|
||||
if self.runner_process.returncode is None:
|
||||
try:
|
||||
# Give the process a moment to exit gracefully
|
||||
await supervisor_write_message(
|
||||
proc=self.runner_process, message=ExitMessage()
|
||||
)
|
||||
_ = await asyncio.wait_for(self.runner_process.wait(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
print("Runner process did not terminate, killing...")
|
||||
await terminate()
|
||||
except Exception:
|
||||
# If we can't write to the process (e.g., broken pipe), it's probably already dead
|
||||
pass
|
||||
|
||||
await wait_for_memory_release()
|
||||
self.running = False
|
||||
|
||||
async def _kill_process_tree(self) -> None:
|
||||
"""Kill the process and all its children forcefully."""
|
||||
if self.runner_process.returncode is not None:
|
||||
return # Process already dead
|
||||
|
||||
try:
|
||||
# Get the main process
|
||||
pid = self.runner_process.pid
|
||||
|
||||
# Find all child processes
|
||||
try:
|
||||
parent = psutil.Process(pid)
|
||||
children = parent.children(recursive=True)
|
||||
|
||||
# Kill all children first (bottom-up)
|
||||
for child in reversed(children):
|
||||
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
child.kill() # SIGKILL
|
||||
|
||||
# Kill the parent
|
||||
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
parent.kill() # SIGKILL
|
||||
|
||||
except psutil.NoSuchProcess:
|
||||
# Process already gone, try subprocess kill anyway
|
||||
self.runner_process.kill()
|
||||
|
||||
# Wait for the subprocess to exit
|
||||
try:
|
||||
await asyncio.wait_for(self.runner_process.wait(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.error(f"Process {pid} did not exit after kill signal")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error killing process tree: {e}")
|
||||
|
||||
async def _watch_runner(self) -> None:
|
||||
_ = await self.runner_process.wait()
|
||||
returncode = await self.runner_process.wait()
|
||||
self.running = False
|
||||
if returncode != 0:
|
||||
self.crash_detected = True
|
||||
self.returncode = returncode # Will be picked up by _watch_stderr too
|
||||
|
||||
async def _watch_stderr(self, logger: Logger) -> None:
|
||||
assert self.runner_process.stderr is not None
|
||||
while self.running:
|
||||
try:
|
||||
line_bytes = await self.runner_process.stderr.readline()
|
||||
if not line_bytes:
|
||||
break # EOF
|
||||
line = line_bytes.decode('utf-8').strip()
|
||||
self.stderr_buffer.append(line)
|
||||
logger.error(f"Runner stderr: {line}")
|
||||
# Detect common crash patterns (extend as needed, e.g., for OOM: "Killed" or "Out of memory")
|
||||
|
||||
self.crash_detected = True
|
||||
self.stderr_output = "\n".join(self.stderr_buffer)
|
||||
logger.critical(f"Runner crash detected: {self.stderr_output}")
|
||||
# Don't raise here—let callers (e.g., stream_response) detect via healthy/returncode
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading runner stderr: {e}")
|
||||
break
|
||||
|
||||
# After EOF, inspect returncode for confirmation (Unix-like: negative == signal)
|
||||
returncode = self.runner_process.returncode
|
||||
if returncode is not None and returncode != 0:
|
||||
self.crash_detected = True
|
||||
self.returncode = returncode
|
||||
self.stderr_output = "\n".join(self.stderr_buffer)
|
||||
|
||||
def _raise_if_crashed(self) -> None:
|
||||
if self.crash_detected:
|
||||
self.logger.error(f'Error {self.returncode}: {self.stderr_output}')
|
||||
raise RunnerError(
|
||||
error_type="RunnerCrash",
|
||||
error_message=self.stderr_output,
|
||||
traceback=traceback.format_exc(),
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.running:
|
||||
print(
|
||||
"Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process."
|
||||
"Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process tree."
|
||||
)
|
||||
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
self.runner_process.kill()
|
||||
# Can't use async in __del__, so use psutil directly
|
||||
try:
|
||||
pid = self.runner_process.pid
|
||||
if pid:
|
||||
parent = psutil.Process(pid)
|
||||
children = parent.children(recursive=True)
|
||||
for child in reversed(children):
|
||||
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
child.kill()
|
||||
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
parent.kill()
|
||||
except Exception:
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
self.runner_process.kill()
|
||||
|
||||
@property
|
||||
def healthy(self) -> bool:
|
||||
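A sketch of driving the supervisor with the new explicit timeout, assuming a ShardMetadata and hosts list are already in hand (the supervisor's module path is an assumption):

import asyncio
import logging

from shared.types.worker.common import RunnerError
from worker.runner.supervisor import RunnerSupervisor  # module path assumed

async def spin_up(shard_meta, hosts) -> RunnerSupervisor | None:
    logger = logging.getLogger("worker_logger")
    try:
        # When initialize_timeout is omitted, get_init_timeout(shard_meta) derives one from the shard size.
        return await RunnerSupervisor.create(shard_meta, hosts, logger, initialize_timeout=120.0)
    except asyncio.TimeoutError:
        logger.error("runner did not initialize in time")
    except RunnerError as e:
        logger.error(f"runner failed during setup: {e.error_type}: {e.error_message}")
    return None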
@@ -178,7 +262,7 @@ class RunnerSupervisor:
|
||||
async def stream_response(
|
||||
self,
|
||||
task: Task,
|
||||
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, # fyi this is async now
|
||||
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None, # fyi this is async now
|
||||
) -> AsyncGenerator[GenerationChunk]:
|
||||
"""
|
||||
Streams a chat request from the model.
|
||||
@@ -187,50 +271,52 @@ class RunnerSupervisor:
|
||||
"""
|
||||
if not self.healthy:
|
||||
raise RuntimeError("Runner process was found to be dead")
|
||||
|
||||
task_params = task.task_params
|
||||
assert isinstance(task_params, ChatCompletionTaskParams) # this is messy for now.
|
||||
assert isinstance(task_params, ChatCompletionTaskParams) # this is messy for now.
|
||||
await supervisor_write_message(
|
||||
proc=self.runner_process,
|
||||
message=ChatTaskMessage(
|
||||
task_data=task_params,
|
||||
),
|
||||
)
|
||||
|
||||
# This is easy for now. If we need more reliability, the runner can have a new 'ready' message type.
|
||||
if request_started_callback is not None:
|
||||
await request_started_callback()
|
||||
|
||||
|
||||
prefil_timeout = get_prefil_timeout(self.model_shard_meta)
|
||||
token_timeout = get_token_generate_timeout(self.model_shard_meta)
|
||||
timeout = prefil_timeout
|
||||
while True:
|
||||
line: RunnerResponse | None = await supervisor_read_response(
|
||||
self.runner_process
|
||||
)
|
||||
if line is None:
|
||||
continue
|
||||
else:
|
||||
match line:
|
||||
case GenerationResponse(
|
||||
text=text, token=token, finish_reason=finish_reason
|
||||
):
|
||||
yield TokenChunk(
|
||||
command_id=CommandId(task.command_id),
|
||||
idx=token,
|
||||
model=self.model_shard_meta.model_meta.model_id,
|
||||
text=text,
|
||||
token_id=token,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
case InitializedResponse():
|
||||
raise ValueError('Initialized Response read during streaming flow')
|
||||
case FinishedResponse():
|
||||
break
|
||||
case PrintResponse(text=text):
|
||||
print(f"runner printed: {text}")
|
||||
case ErrorResponse(
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
traceback=traceback,
|
||||
):
|
||||
await self.astop()
|
||||
raise Exception(error_type, error_message, traceback or "")
|
||||
try:
|
||||
line: RunnerResponse | None = await asyncio.wait_for(supervisor_read_response(
|
||||
self.runner_process
|
||||
), timeout=timeout)
|
||||
if line is None:
|
||||
continue
|
||||
except (asyncio.TimeoutError, EOFError) as e:
|
||||
self._raise_if_crashed()
|
||||
raise RunnerError(
|
||||
error_type=type(e).__name__,
|
||||
error_message=str(e),
|
||||
traceback="",
|
||||
) from e
|
||||
match line:
|
||||
case GenerationResponse():
|
||||
yield TokenChunk(
|
||||
command_id=CommandId(task.command_id),
|
||||
idx=line.token,
|
||||
model=self.model_shard_meta.model_meta.model_id,
|
||||
text=line.text,
|
||||
token_id=line.token,
|
||||
finish_reason=line.finish_reason,
|
||||
)
|
||||
timeout = token_timeout
|
||||
case InitializedResponse():
|
||||
raise ValueError('Initialized Response read during streaming flow')
|
||||
case FinishedResponse():
|
||||
break
|
||||
case PrintResponse():
|
||||
# print(f"runner printed: {line.text}")
|
||||
self.logger.info(f"runner printed: {line.text}")
|
||||
case ErrorResponse():
|
||||
await self.astop()
|
||||
raise RunnerError(line.error_type, line.error_message, line.traceback or "")
|
||||
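The streaming loop above allows a longer timeout before the first token (prefill) and a shorter one for every token after it. A standalone sketch of that pattern over a generic async read:

import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable

async def read_with_timeouts(
    read_next: Callable[[], Awaitable[str | None]],
    prefill_timeout: float,
    token_timeout: float,
) -> AsyncIterator[str]:
    # The first read gets the prefill budget; once tokens flow, switch to the per-token budget.
    timeout = prefill_timeout
    while True:
        try:
            item = await asyncio.wait_for(read_next(), timeout=timeout)
        except asyncio.TimeoutError:
            raise RuntimeError("runner stopped producing output") from None
        if item is None:  # stream finished
            return
        yield item
        timeout = token_timeout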
@@ -1,6 +1,34 @@
|
||||
import sys
|
||||
|
||||
from shared.constants import LB_DISK_GBPS, LB_MEMBW_GBPS, LB_TFLOPS
|
||||
from shared.types.worker.shards import ShardMetadata
|
||||
|
||||
|
||||
def get_runner_command() -> list[str]:
|
||||
python = sys.executable
|
||||
return [python, "-m", "worker.runner.runner"]
|
||||
|
||||
def get_weights_size_kb(model_shard_meta: ShardMetadata) -> float:
|
||||
return (model_shard_meta.end_layer - model_shard_meta.start_layer) / model_shard_meta.n_layers * model_shard_meta.model_meta.storage_size_kilobytes
|
||||
|
||||
def get_init_timeout(model_shard_meta: ShardMetadata) -> float:
|
||||
weights_size_kb = get_weights_size_kb(model_shard_meta)
|
||||
|
||||
kbps_read = 1024 * 1024 * LB_DISK_GBPS / 3
|
||||
|
||||
return weights_size_kb / kbps_read + 2.0
|
||||
|
||||
def get_prefil_timeout(model_shard_meta: ShardMetadata) -> float:
|
||||
weights_size_gb = get_weights_size_kb(model_shard_meta) / (1024 * 1024)
|
||||
|
||||
tokens = 1000 # constant for now - the prompt is only tokenized on the device...
|
||||
prompt_gflops = tokens * weights_size_gb * 2
|
||||
|
||||
return prompt_gflops / (1024 * LB_TFLOPS) * 3 + 10.0  # work divided by the lower-bound compute budget, with 3x headroom and a 10s floor
|
||||
|
||||
def get_token_generate_timeout(model_shard_meta: ShardMetadata) -> float:
|
||||
weights_size_kb = get_weights_size_kb(model_shard_meta)
|
||||
|
||||
kbps_read = 1024 * 1024 * LB_MEMBW_GBPS / 3
|
||||
|
||||
return weights_size_kb / kbps_read + 2.0
|
||||
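A worked example of get_init_timeout, assuming LB_DISK_GBPS = 1 purely for illustration (the real constant lives in shared.constants) and a shard holding 8 GiB of weights:

LB_DISK_GBPS = 1  # illustrative value only

weights_size_kb = 8 * 1024 * 1024            # 8 GiB of weights for this shard, in KiB
kbps_read = 1024 * 1024 * LB_DISK_GBPS / 3   # assume only a third of the lower-bound disk bandwidth is achieved
init_timeout = weights_size_kb / kbps_read + 2.0

print(f"{init_timeout:.1f}s")  # 8 GiB / (1/3 GiB/s) + 2s = 26.0s before initialization is declared timed out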
@@ -112,7 +112,6 @@ def instance(pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], h
|
||||
|
||||
@pytest.fixture
|
||||
def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
|
||||
"""Creates ChatCompletionParams with the given message"""
|
||||
return ChatCompletionTaskParams(
|
||||
model="gpt-4",
|
||||
messages=[ChatCompletionMessage(role="user", content=user_message)],
|
||||
@@ -121,19 +120,19 @@ def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
|
||||
|
||||
@pytest.fixture
|
||||
def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
|
||||
def _chat_completion_task(instance_id: Optional[InstanceId] = None, task_id: Optional[TaskId] = None) -> ChatCompletionTask:
|
||||
if instance_id is None:
|
||||
instance_id = INSTANCE_1_ID
|
||||
if task_id is None:
|
||||
task_id = TASK_1_ID
|
||||
def _chat_completion_task(
|
||||
instance_id: Optional[InstanceId] = None,
|
||||
task_id: Optional[TaskId] = None,
|
||||
user_message: str = "Hello"
|
||||
) -> ChatCompletionTask:
|
||||
resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID
|
||||
resolved_task_id = task_id if task_id is not None else TASK_1_ID
|
||||
return ChatCompletionTask(
|
||||
task_id=task_id,
|
||||
task_id=resolved_task_id,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=instance_id,
|
||||
instance_id=resolved_instance_id,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=completion_create_params
|
||||
)
|
||||
return _chat_completion_task
|
||||
|
||||
|
||||
|
||||
0
worker/tests/test_handlers/__init__.py
Normal file
0
worker/tests/test_handlers/__init__.py
Normal file
@@ -1,28 +1,46 @@
|
||||
## Tests for worker state handlers
|
||||
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.types.events import (
|
||||
RunnerStatusUpdated,
|
||||
TaskFailed,
|
||||
TaskStateUpdated,
|
||||
)
|
||||
from shared.types.tasks import ChatCompletionTask, TaskStatus
|
||||
from shared.types.tasks import ChatCompletionTask
|
||||
from shared.types.worker.common import RunnerError
|
||||
from shared.types.worker.instances import Instance
|
||||
from shared.types.worker.ops import (
|
||||
ExecuteTaskOp,
|
||||
)
|
||||
from shared.types.worker.runners import (
|
||||
FailedRunnerStatus,
|
||||
RunningRunnerStatus,
|
||||
RunnerUpOp,
|
||||
)
|
||||
from worker.main import Worker
|
||||
from worker.tests.constants import RUNNER_1_ID
|
||||
from worker.tests.test_handlers.utils import read_events_op
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_up_fails(
|
||||
worker_with_assigned_runner: tuple[Worker, Instance],
|
||||
chat_completion_task: Callable[[], ChatCompletionTask]):
|
||||
worker, _ = worker_with_assigned_runner
|
||||
worker.assigned_runners[RUNNER_1_ID].shard_metadata.immediate_exception = True
|
||||
|
||||
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
|
||||
|
||||
with pytest.raises(RunnerError):
|
||||
await read_events_op(worker, runner_up_op)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_up_timeouts(
|
||||
worker_with_assigned_runner: tuple[Worker, Instance],
|
||||
chat_completion_task: Callable[[], ChatCompletionTask]):
|
||||
worker, _ = worker_with_assigned_runner
|
||||
worker.assigned_runners[RUNNER_1_ID].shard_metadata.should_timeout = 10
|
||||
|
||||
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await read_events_op(worker, runner_up_op)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_fails(
|
||||
worker_with_running_runner: tuple[Worker, Instance],
|
||||
@@ -38,24 +56,27 @@ async def test_execute_task_fails(
|
||||
task=task
|
||||
)
|
||||
|
||||
events = await read_events_op(worker, execute_task_op)
|
||||
with pytest.raises(RunnerError):
|
||||
await read_events_op(worker, execute_task_op)
|
||||
|
||||
assert len(events) == 5
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_timeouts(
|
||||
worker_with_running_runner: tuple[Worker, Instance],
|
||||
chat_completion_task: Callable[[], ChatCompletionTask]):
|
||||
worker, _ = worker_with_running_runner
|
||||
|
||||
print(events)
|
||||
task = chat_completion_task()
|
||||
messages = task.task_params.messages
|
||||
messages[0].content = 'Artificial prompt: EXO RUNNER MUST TIMEOUT'
|
||||
|
||||
assert isinstance(events[0], RunnerStatusUpdated)
|
||||
assert isinstance(events[0].runner_status, RunningRunnerStatus) # It tried to start.
|
||||
execute_task_op = ExecuteTaskOp(
|
||||
runner_id=RUNNER_1_ID,
|
||||
task=task
|
||||
)
|
||||
|
||||
assert isinstance(events[1], TaskStateUpdated)
|
||||
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
|
||||
with pytest.raises(RunnerError): # At the moment this is a RunnerError that says 'TimeoutError'.
|
||||
await read_events_op(worker, execute_task_op)
|
||||
|
||||
assert isinstance(events[2], TaskStateUpdated)
|
||||
assert events[2].task_status == TaskStatus.FAILED # Task marked as failed.
|
||||
|
||||
assert isinstance(events[3], TaskFailed)
|
||||
|
||||
assert isinstance(events[4], RunnerStatusUpdated)
|
||||
assert isinstance(events[4].runner_status, FailedRunnerStatus) # It should have failed.
|
||||
|
||||
# TODO: Much more to do here!
|
||||
# TODO: Much more to do here!
|
||||
# runner assigned download stuff
|
||||
0
worker/tests/test_integration/__init__.py
Normal file
0
worker/tests/test_integration/__init__.py
Normal file
@@ -29,7 +29,7 @@ def worker_running(logger: Logger) -> Callable[[NodeId], Awaitable[tuple[Worker,
|
||||
|
||||
shard_downloader = NoopShardDownloader()
|
||||
worker = Worker(node_id, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(run(worker))
|
||||
asyncio.create_task(run(worker, logger))
|
||||
|
||||
return worker, global_events
|
||||
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
from typing import Tuple
|
||||
from typing import Callable, Optional, Tuple, TypeVar
|
||||
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.types.events import ChunkGenerated, TaskStateUpdated
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.tasks import TaskStatus
|
||||
from shared.types.tasks import TaskId, TaskStatus
|
||||
|
||||
|
||||
async def read_streaming_response(global_events: AsyncSQLiteEventStorage) -> Tuple[bool, bool, str]:
|
||||
async def read_streaming_response(global_events: AsyncSQLiteEventStorage, filter_task: Optional[TaskId] = None) -> Tuple[bool, bool, str]:
|
||||
# Read off all events - these should be our GenerationChunk events
|
||||
seen_task_started, seen_task_finished = 0, 0
|
||||
response_string = ''
|
||||
finish_reason: str | None = None
|
||||
|
||||
idx = 0
|
||||
if not filter_task:
|
||||
idx = await global_events.get_last_idx()
|
||||
else:
|
||||
found = False
|
||||
idx = 0
|
||||
while not found:
|
||||
events = await global_events.get_events_since(idx)
|
||||
|
||||
for event in events:
|
||||
if isinstance(event.event, TaskStateUpdated) and event.event.task_status == TaskStatus.RUNNING and event.event.task_id == filter_task:
|
||||
found = True
|
||||
idx = event.idx_in_log - 1
|
||||
break
|
||||
|
||||
print(f'START IDX {idx}')
|
||||
|
||||
while not finish_reason:
|
||||
events = await global_events.get_events_since(idx)
|
||||
if len(events) == 0:
|
||||
@@ -41,4 +56,26 @@ async def read_streaming_response(global_events: AsyncSQLiteEventStorage) -> Tup
|
||||
|
||||
print(f'event log: {await global_events.get_events_since(0)}')
|
||||
|
||||
return seen_task_started == 1, seen_task_finished == 1, response_string
|
||||
return seen_task_started == 1, seen_task_finished == 1, response_string
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
async def until_event_with_timeout(
|
||||
global_events: AsyncSQLiteEventStorage,
|
||||
event_type: type[T],
|
||||
multiplicity: int = 1,
|
||||
condition: Callable[[T], bool] = lambda x: True,
|
||||
) -> None:
|
||||
idx = await global_events.get_last_idx()
|
||||
times_seen = 0
|
||||
while True:
|
||||
events = await global_events.get_events_since(idx)
|
||||
if events:
|
||||
for wrapped_event in events:
|
||||
if isinstance(wrapped_event.event, event_type) and condition(wrapped_event.event):
|
||||
times_seen += 1
|
||||
if times_seen >= multiplicity:
|
||||
return
|
||||
idx = events[-1].idx_in_log
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
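until_event_with_timeout has no internal timeout of its own; it polls until the event has been seen the requested number of times, so tests bound it externally. A sketch using the import paths from this diff:

import asyncio

from shared.types.events import RunnerStatusUpdated
from worker.tests.test_integration.integration_utils import until_event_with_timeout

async def wait_for_two_runner_updates(global_events) -> None:
    # An outer wait_for makes a missing event fail the test instead of hanging it.
    await asyncio.wait_for(
        until_event_with_timeout(global_events, RunnerStatusUpdated, multiplicity=2),
        timeout=10.0,
    )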
@@ -1,351 +0,0 @@
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
|
||||
from shared.types.common import Host, NodeId
|
||||
from shared.types.events import (
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskCreated,
|
||||
)
|
||||
from shared.types.events.chunks import TokenChunk
|
||||
from shared.types.models import ModelId
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceStatus,
|
||||
ShardAssignments,
|
||||
)
|
||||
from shared.types.worker.runners import (
|
||||
DownloadingRunnerStatus,
|
||||
# RunningRunnerStatus,
|
||||
FailedRunnerStatus,
|
||||
InactiveRunnerStatus,
|
||||
LoadedRunnerStatus,
|
||||
)
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.common import AssignedRunner
|
||||
from worker.download.shard_downloader import NoopShardDownloader
|
||||
from worker.main import run
|
||||
from worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MASTER_NODE_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
TASK_1_ID,
|
||||
TASK_2_ID,
|
||||
)
|
||||
from worker.tests.test_integration.integration_utils import (
|
||||
read_streaming_response,
|
||||
)
|
||||
from worker.worker import Worker
|
||||
|
||||
|
||||
async def test_runner_assigned(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
|
||||
):
|
||||
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.INACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Ensure the worker has taken the correct action
|
||||
assert len(worker.assigned_runners) == 1
|
||||
assert RUNNER_1_ID in worker.assigned_runners
|
||||
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus)
|
||||
|
||||
# Ensure the correct events have been emitted
|
||||
events = await global_events.get_events_since(0)
|
||||
assert len(events) >= 3 # len(events) is 4 if it's already downloaded. It is > 4 if there have to be download events.
|
||||
|
||||
assert isinstance(events[1].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
|
||||
assert isinstance(events[-1].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[-1].event.runner_status, InactiveRunnerStatus)
|
||||
|
||||
# Ensure state is correct
|
||||
assert isinstance(worker.state.runners[RUNNER_1_ID], InactiveRunnerStatus)
|
||||
|
||||
async def test_runner_assigned_active(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task]
|
||||
):
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
assert len(worker.assigned_runners) == 1
|
||||
assert RUNNER_1_ID in worker.assigned_runners
|
||||
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, LoadedRunnerStatus)
|
||||
|
||||
# Ensure the correct events have been emitted
|
||||
events = await global_events.get_events_since(0)
|
||||
assert len(events) >= 4 # len(events) is 5 if it's already downloaded. It is > 5 if there have to be download events.
|
||||
assert isinstance(events[1].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[1].event.runner_status, DownloadingRunnerStatus)
|
||||
assert isinstance(events[-2].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[-2].event.runner_status, InactiveRunnerStatus)
|
||||
assert isinstance(events[-1].event, RunnerStatusUpdated)
|
||||
assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)
|
||||
|
||||
# Ensure state is correct
|
||||
assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)
|
||||
|
||||
# Ensure that the runner has been created and it can stream tokens.
|
||||
supervisor = next(iter(worker.assigned_runners.values())).runner
|
||||
assert supervisor is not None
|
||||
assert supervisor.healthy
|
||||
|
||||
full_response = ''
|
||||
|
||||
async for chunk in supervisor.stream_response(task=chat_completion_task(INSTANCE_1_ID, TASK_1_ID)):
|
||||
if isinstance(chunk, TokenChunk):
|
||||
full_response += chunk.text
|
||||
|
||||
assert "tokyo" in full_response.lower(), (
|
||||
f"Expected 'Tokyo' in response, but got: {full_response}"
|
||||
)
|
||||
|
||||
async def test_runner_assigned_wrong_node(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
|
||||
):
|
||||
worker, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value = instance(INSTANCE_1_ID, NODE_B, RUNNER_1_ID)
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(worker.assigned_runners) == 0
|
||||
|
||||
    # Ensure the correct events have been emitted
    events = await global_events.get_events_since(0)
    assert len(events) == 1
    # No RunnerStatusUpdated event should be emitted

    # Ensure state is correct
    assert len(worker.state.runners) == 0


async def test_runner_unassigns(
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance]
):
    worker, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    await global_events.append_events(
        [
            InstanceCreated(
                instance=instance_value
            )
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(2.0)

    # already tested by test_runner_assigned_active
    assert len(worker.assigned_runners) == 1
    assert RUNNER_1_ID in worker.assigned_runners
    assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, LoadedRunnerStatus)

    # Ensure the correct events have been emitted (creation)
    events = await global_events.get_events_since(0)
    assert len(events) >= 4
    assert isinstance(events[-1].event, RunnerStatusUpdated)
    assert isinstance(events[-1].event.runner_status, LoadedRunnerStatus)

    # Ensure state is correct
    assert isinstance(worker.state.runners[RUNNER_1_ID], LoadedRunnerStatus)

    await global_events.append_events(
        [
            InstanceDeleted(instance_id=instance_value.instance_id)
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(0.3)

    assert len(worker.assigned_runners) == 0

    # Ensure the correct events have been emitted (deletion)
    events = await global_events.get_events_since(0)
    assert isinstance(events[-1].event, RunnerDeleted)
    # After deletion, runner should be removed from state.runners
    assert len(worker.state.runners) == 0


async def test_runner_respawn(
    logger: Logger,
    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
    hosts: Callable[[int], list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
    event_log_manager = EventLogManager(EventLogConfig(), logger)
    await event_log_manager.initialize()
    shard_downloader = NoopShardDownloader()

    global_events = event_log_manager.global_events
    await global_events.delete_all_events()

    worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    asyncio.create_task(run(worker1))

    worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    asyncio.create_task(run(worker2))

    ## Instance
    model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')

    shard_assignments = ShardAssignments(
        model_id=model_id,
        runner_to_shard={
            RUNNER_1_ID: pipeline_shard_meta(2, 0),
            RUNNER_2_ID: pipeline_shard_meta(2, 1)
        },
        node_to_runner={
            NODE_A: RUNNER_1_ID,
            NODE_B: RUNNER_2_ID
        }
    )

    instance = Instance(
        instance_id=INSTANCE_1_ID,
        instance_type=InstanceStatus.ACTIVE,
        shard_assignments=shard_assignments,
        hosts=hosts(2)
    )

    task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    await global_events.append_events(
        [
            InstanceCreated(
                instance=instance
            ),
            TaskCreated(
                task_id=task.task_id,
                task=task
            )
        ],
        origin=MASTER_NODE_ID
    )

    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)

    assert seen_task_started
    assert seen_task_finished
    assert 'tokyo' in response_string.lower()

    await asyncio.sleep(0.1)

    idx = await global_events.get_last_idx()

    assigned_runner: AssignedRunner = worker1.assigned_runners[RUNNER_1_ID]
    assert assigned_runner.runner is not None
    assigned_runner.runner.runner_process.kill()

    # Wait for the process to actually be detected as dead or cleaned up
    for _ in range(100):  # Wait up to 1 second
        await asyncio.sleep(0.01)
        # The worker may clean up the runner (set to None) when it detects it's dead
        if assigned_runner.runner and not assigned_runner.runner.healthy:
            break
    else:
        raise AssertionError("Runner should have been detected as unhealthy or cleaned up after kill()")

    await asyncio.sleep(5.0)

    events = await global_events.get_events_since(idx)
    # assert len(events) == 2
    assert isinstance(events[0].event, RunnerStatusUpdated)
    assert isinstance(events[0].event.runner_status, FailedRunnerStatus)

    assert isinstance(events[1].event, RunnerStatusUpdated)
    assert isinstance(events[1].event.runner_status, InactiveRunnerStatus)
    assert events[1].event.runner_id == RUNNER_2_ID

    assert isinstance(events[2].event, RunnerStatusUpdated)
    assert isinstance(events[2].event.runner_status, InactiveRunnerStatus)
    assert events[2].event.runner_id == RUNNER_1_ID

    for event in [events[3].event, events[4].event]:
        assert isinstance(event, RunnerStatusUpdated)
        assert isinstance(event.runner_status, LoadedRunnerStatus)

    task = chat_completion_task(INSTANCE_1_ID, TASK_2_ID)
    await global_events.append_events(
        [
            TaskCreated(
                task_id=task.task_id,
                task=task
            )
        ],
        origin=MASTER_NODE_ID
    )

    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)

    assert seen_task_started
    assert seen_task_finished
    assert 'tokyo' in response_string.lower()

    await asyncio.sleep(0.1)

    await global_events.append_events(
        [
            InstanceDeleted(
                instance_id=instance.instance_id,
            ),
        ],
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(1.0)
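Note: the tests above and below rely on a read_streaming_response helper from worker/tests/test_integration/integration_utils.py, which this diff does not show. The following is only a rough sketch of the poll-and-accumulate pattern those tests assume; the event-row fields (idx, event), the TokenChunk text attribute, the default timeout and the sleep interval are guesses, not the real implementation.

# Illustrative sketch only -- names marked below are assumptions, not code from this commit.
import asyncio

from shared.types.events import ChunkGenerated, TaskStateUpdated
from shared.types.events.chunks import TokenChunk
from shared.types.tasks import TaskStatus


async def read_streaming_response(global_events, filter_task=None, timeout=60.0):
    """Poll the shared event log and accumulate streamed tokens for one task."""
    seen_started, seen_finished = False, False
    parts: list[str] = []
    idx = 0
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while not seen_finished and loop.time() < deadline:
        rows = await global_events.get_events_since(idx)
        for row in rows:
            idx = max(idx, row.idx)  # assumed: each stored row exposes its log index
            event = row.event
            if isinstance(event, TaskStateUpdated):
                if filter_task is not None and event.task_id != filter_task:
                    continue
                if event.task_status == TaskStatus.RUNNING:
                    seen_started = True
                elif event.task_status == TaskStatus.COMPLETE:
                    seen_finished = True
            elif isinstance(event, ChunkGenerated) and isinstance(event.chunk, TokenChunk):
                parts.append(event.chunk.text)  # assumed attribute for the decoded token text
        await asyncio.sleep(0.05)
    return seen_started, seen_finished, ''.join(parts)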
@@ -62,6 +62,7 @@ async def test_runner_inference(
        origin=MASTER_NODE_ID
    )

    # TODO: This needs to get fixed - sometimes it misses the 'starting' event.
    seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)

    assert seen_task_started
@@ -93,10 +94,10 @@ async def test_2_runner_inference(
    await global_events.delete_all_events()

    worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    asyncio.create_task(run(worker1))
    asyncio.create_task(run(worker1, logger))

    worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    asyncio.create_task(run(worker2))
    asyncio.create_task(run(worker2, logger))

    ## Instance
    model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
@@ -171,10 +172,10 @@ async def test_2_runner_multi_message(
    await global_events.delete_all_events()

    worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    asyncio.create_task(run(worker1))
    asyncio.create_task(run(worker1, logger))

    worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
    asyncio.create_task(run(worker2))
    asyncio.create_task(run(worker2, logger))

    ## Instance
    model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')

@@ -17,6 +17,7 @@ from shared.types.events import (
    TaskCreated,
    TaskStateUpdated,
)
from shared.types.events._events import TaskFailed
from shared.types.events.chunks import GenerationChunk, TokenChunk
from shared.types.tasks import Task, TaskId, TaskStatus
from shared.types.worker.common import InstanceId, RunnerId
@@ -34,6 +35,7 @@ from worker.tests.constants import (
    RUNNER_1_ID,
    TASK_1_ID,
)
from worker.tests.test_integration.integration_utils import until_event_with_timeout


@pytest.fixture
@@ -41,14 +43,13 @@ def user_message():
    """Override this fixture in tests to customize the message"""
    return "Who is the longest ruling monarch of England?"

# TODO: Make this all monkeypatched instead.

async def test_stream_response_failed_always(
    monkeypatch: MonkeyPatch,
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
) -> None:
    _, global_events = await worker_running(NODE_A)

    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
@@ -74,7 +75,7 @@ async def test_stream_response_failed_always(
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(5.)
    await until_event_with_timeout(global_events, InstanceDeleted)


    events = await global_events.get_events_since(0)
@@ -133,7 +134,7 @@ async def test_stream_response_failed_once(
        origin=MASTER_NODE_ID
    )

    await asyncio.sleep(5.)
    await until_event_with_timeout(global_events, ChunkGenerated, 1, condition=lambda x: isinstance(x.chunk, TokenChunk) and x.chunk.finish_reason is not None)

    # TODO: The ideal with this test is if we had some tooling to scroll through the state, and say
    # 'assert that there was a time that the error_type, error_message was not none and the failure count was nonzero'
@@ -179,65 +180,41 @@ async def test_stream_response_failed_once(
    await asyncio.sleep(0.3)


# async def test_stream_response_timeout(
#     monkeypatch: MonkeyPatch,
#     worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
#     instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
#     chat_completion_task: Callable[[InstanceId, TaskId], Task]
# ):
#     async def mock_stream_response(
#         self: RunnerSupervisor,
#         task: Task,
#         request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
#     ) -> AsyncGenerator[GenerationChunk]:
#         # TODO: Also a test where we yield a few chunks and then time out.
#         print('sleeping starting')
#         await asyncio.sleep(4.)
#         print('sleeping finished')
#         return
#         yield
async def test_stream_response_timeout(
    worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
    instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
    chat_completion_task: Callable[[InstanceId, TaskId], Task]
):
    _, global_events = await worker_running(NODE_A)

    # monkeypatch.setattr(RunnerSupervisor, 'stream_response', mock_stream_response)

    # worker, global_events = await worker_running(NODE_A)
    instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    instance_value.instance_type = InstanceStatus.ACTIVE

    # instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
    # instance_value.instance_type = InstanceStatus.ACTIVE
    task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    task.task_params.messages[0].content = 'EXO RUNNER MUST TIMEOUT'
    await global_events.append_events(
        [
            InstanceCreated(instance=instance_value),
            TaskCreated(task_id=task.task_id, task=task)
        ],
        origin=MASTER_NODE_ID
    )

    # task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    # await global_events.append_events(
    #     [
    #         InstanceCreated(instance=instance_value),
    #         TaskCreated(task_id=task.task_id, task=task)
    #     ],
    #     origin=MASTER_NODE_ID
    # )
    await until_event_with_timeout(global_events, TaskFailed, multiplicity=3)

    # await asyncio.sleep(7.)

    events = await global_events.get_events_since(0)
    print(events)
    assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
    assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 3
    assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_message.lower()]) == 3

    # # as we reset the failures back to zero when we have a successful inference.
    await global_events.append_events(
        [
            InstanceDeleted(
                instance_id=instance_value.instance_id,
            ),
        ],
        origin=MASTER_NODE_ID
    )

    # # print('ASSERTION ERR:')
    # # print(worker.assigned_runners[RUNNER_1_ID].failures[1][1])

    # assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
    # assert worker.state.tasks[TASK_1_ID].error_type is None
    # assert worker.state.tasks[TASK_1_ID].error_message is None

    # events = await global_events.get_events_since(0)
    # print(events)
    # assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 1
    # assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 1
    # assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 1

    # await global_events.append_events(
    #     [
    #         InstanceDeleted(
    #             instance_id=instance_value.instance_id,
    #         ),
    #     ],
    #     origin=MASTER_NODE_ID
    # )

    # await asyncio.sleep(0.3)
    await asyncio.sleep(0.3)
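For context, until_event_with_timeout (imported from integration_utils above) is the main synchronisation primitive in these sad-path tests, but its body is not part of this diff. A minimal sketch of the semantics the tests rely on -- wait until `multiplicity` matching events exist, optionally filtered by `condition`, or raise on timeout -- could look like this; the exact signature and the 30-second default are assumptions.

# Rough sketch, not the actual helper; only get_events_since is taken from the tests above.
import asyncio


async def until_event_with_timeout(global_events, event_type, multiplicity=1, condition=None, timeout=30.0):
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        rows = await global_events.get_events_since(0)
        matches = [
            row.event for row in rows
            if isinstance(row.event, event_type) and (condition is None or condition(row.event))
        ]
        if len(matches) >= multiplicity:
            return matches
        await asyncio.sleep(0.1)
    raise asyncio.TimeoutError(f"expected {multiplicity} {event_type.__name__} event(s) within {timeout}s")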
85
worker/tests/test_integration/test_instantiation.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.types.common import NodeId
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.types.events import (
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
RunnerStatusUpdated,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceStatus,
|
||||
)
|
||||
from shared.types.worker.runners import (
|
||||
FailedRunnerStatus,
|
||||
)
|
||||
from worker.main import Worker
|
||||
from worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MASTER_NODE_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from worker.tests.test_integration.integration_utils import until_event_with_timeout
|
||||
|
||||
|
||||
async def test_runner_spinup_exception(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
):
|
||||
_, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].immediate_exception = True
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
# Ensure the correct events have been emitted
|
||||
events = await global_events.get_events_since(0)
|
||||
|
||||
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
|
||||
|
||||
async def test_runner_spinup_timeout(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
):
|
||||
_, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await until_event_with_timeout(global_events, RunnerStatusUpdated, multiplicity=3, condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus))
|
||||
|
||||
# Ensure the correct events have been emitted
|
||||
events = await global_events.get_events_since(0)
|
||||
|
||||
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
85
worker/tests/test_integration/test_instantiation_sad.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
|
||||
from shared.types.common import NodeId
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.types.events import (
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
RunnerStatusUpdated,
|
||||
)
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceStatus,
|
||||
)
|
||||
from shared.types.worker.runners import (
|
||||
FailedRunnerStatus,
|
||||
)
|
||||
from worker.main import Worker
|
||||
from worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MASTER_NODE_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from worker.tests.test_integration.integration_utils import until_event_with_timeout
|
||||
|
||||
|
||||
async def test_runner_spinup_exception(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
):
|
||||
_, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].immediate_exception = True
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
# Ensure the correct events have been emitted
|
||||
events = await global_events.get_events_since(0)
|
||||
|
||||
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
|
||||
|
||||
async def test_runner_spinup_timeout(
|
||||
worker_running: Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]],
|
||||
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
|
||||
):
|
||||
_, global_events = await worker_running(NODE_A)
|
||||
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance_value
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await until_event_with_timeout(global_events, RunnerStatusUpdated, multiplicity=3, condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus))
|
||||
|
||||
# Ensure the correct events have been emitted
|
||||
events = await global_events.get_events_since(0)
|
||||
|
||||
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
258
worker/tests/test_multimodel/test_inference_llama70B.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
|
||||
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
|
||||
from shared.models.model_meta import get_model_meta
|
||||
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from shared.types.common import Host
|
||||
from shared.types.events import (
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskCreated,
|
||||
)
|
||||
from shared.types.models import ModelId, ModelMetadata
|
||||
from shared.types.tasks import ChatCompletionTask, Task, TaskId, TaskStatus, TaskType
|
||||
from shared.types.worker.common import InstanceId
|
||||
from shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceStatus,
|
||||
ShardAssignments,
|
||||
)
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.download.shard_downloader import NoopShardDownloader
|
||||
from worker.main import run
|
||||
from worker.tests.constants import (
|
||||
COMMAND_1_ID,
|
||||
COMMAND_2_ID,
|
||||
INSTANCE_1_ID,
|
||||
MASTER_NODE_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
TASK_1_ID,
|
||||
TASK_2_ID,
|
||||
)
|
||||
from worker.tests.test_integration.integration_utils import (
|
||||
read_streaming_response,
|
||||
)
|
||||
from worker.worker import Worker
|
||||
|
||||
MODEL_ID = 'mlx-community/Llama-3.3-70B-Instruct-4bit'
|
||||
|
||||
@pytest.fixture
|
||||
async def model_meta() -> ModelMetadata:
|
||||
return await get_model_meta(MODEL_ID)
|
||||
|
||||
async def test_2_runner_inference(
|
||||
logger: Logger,
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
hosts: Callable[[int], list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task]
|
||||
):
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
shard_downloader = NoopShardDownloader()
|
||||
|
||||
global_events = event_log_manager.global_events
|
||||
await global_events.delete_all_events()
|
||||
|
||||
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(run(worker1, logger))
|
||||
|
||||
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(run(worker2, logger))
|
||||
|
||||
## Instance
|
||||
model_id = ModelId(MODEL_ID)
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_id,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: pipeline_shard_meta(2, 0),
|
||||
RUNNER_2_ID: pipeline_shard_meta(2, 1)
|
||||
},
|
||||
node_to_runner={
|
||||
NODE_A: RUNNER_1_ID,
|
||||
NODE_B: RUNNER_2_ID
|
||||
}
|
||||
)
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2)
|
||||
)
|
||||
|
||||
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
task.task_params.messages[0].content = 'Can you explain to me how a bubble sort works, speaking as if you are a fairy.'
|
||||
task.task_params.max_tokens = 1000
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance
|
||||
),
|
||||
TaskCreated(
|
||||
task_id=task.task_id,
|
||||
task=task
|
||||
)
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
seen_task_started, seen_task_finished, response_string = await read_streaming_response(global_events)
|
||||
|
||||
assert seen_task_started
|
||||
assert seen_task_finished
|
||||
assert 'swap' in response_string.lower()
|
||||
|
||||
|
||||
idx = await global_events.get_last_idx()
|
||||
await asyncio.sleep(1.0)
|
||||
events = await global_events.get_events_since(idx)
|
||||
assert len(events) == 0
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
|
||||
|
||||
|
||||
async def test_parallel_inference(
|
||||
logger: Logger,
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
hosts: Callable[[int], list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task]
|
||||
):
|
||||
event_log_manager = EventLogManager(EventLogConfig(), logger)
|
||||
await event_log_manager.initialize()
|
||||
shard_downloader = NoopShardDownloader()
|
||||
|
||||
global_events = event_log_manager.global_events
|
||||
await global_events.delete_all_events()
|
||||
|
||||
worker1 = Worker(NODE_A, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(run(worker1, logger))
|
||||
|
||||
worker2 = Worker(NODE_B, logger=logger, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events)
|
||||
asyncio.create_task(run(worker2, logger))
|
||||
|
||||
## Instance
|
||||
model_id = ModelId(MODEL_ID)
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_id,
|
||||
runner_to_shard={
|
||||
RUNNER_1_ID: pipeline_shard_meta(2, 0),
|
||||
RUNNER_2_ID: pipeline_shard_meta(2, 1)
|
||||
},
|
||||
node_to_runner={
|
||||
NODE_A: RUNNER_1_ID,
|
||||
NODE_B: RUNNER_2_ID
|
||||
}
|
||||
)
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2)
|
||||
)
|
||||
|
||||
completion_create_params_1 = ChatCompletionTaskParams(
|
||||
model="gpt-4",
|
||||
messages=[ChatCompletionMessage(role="user", content='Tell me a haiku that uses the word "pond".')],
|
||||
stream=True,
|
||||
max_tokens=1000
|
||||
)
|
||||
task1 = ChatCompletionTask(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=completion_create_params_1
|
||||
)
|
||||
|
||||
completion_create_params_2 = ChatCompletionTaskParams(
|
||||
model="gpt-4",
|
||||
messages=[ChatCompletionMessage(role="user", content='Tell me a haiku that uses the word "tree".')],
|
||||
stream=True,
|
||||
max_tokens=1000
|
||||
)
|
||||
task2 = ChatCompletionTask(
|
||||
task_id=TASK_2_ID,
|
||||
command_id=COMMAND_2_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=completion_create_params_2
|
||||
)
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceCreated(
|
||||
instance=instance
|
||||
),
|
||||
TaskCreated(
|
||||
task_id=task1.task_id,
|
||||
task=task1
|
||||
),
|
||||
TaskCreated(
|
||||
task_id=task2.task_id,
|
||||
task=task2
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
seen_task_started_1, seen_task_finished_1, response_string_1 = await read_streaming_response(global_events)
|
||||
|
||||
incomplete_task = TASK_2_ID if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.COMPLETE else TASK_1_ID
|
||||
seen_task_started_2, seen_task_finished_2, response_string_2 = await read_streaming_response(global_events, filter_task=incomplete_task)
|
||||
|
||||
assert seen_task_started_1
|
||||
assert seen_task_finished_1
|
||||
assert seen_task_started_2
|
||||
assert seen_task_finished_2
|
||||
|
||||
print(response_string_1)
|
||||
print(response_string_2)
|
||||
|
||||
assert (
|
||||
('pond' in response_string_1.lower()) ^ ('pond' in response_string_2.lower())
|
||||
), "'pond' must appear in exactly one response"
|
||||
assert (
|
||||
('tree' in response_string_1.lower()) ^ ('tree' in response_string_2.lower())
|
||||
), "'tree' must appear in exactly one response"
|
||||
|
||||
|
||||
idx = await global_events.get_last_idx()
|
||||
await asyncio.sleep(1.0)
|
||||
events = await global_events.get_events_since(idx)
|
||||
assert len(events) == 0
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
InstanceDeleted(
|
||||
instance_id=instance.instance_id,
|
||||
),
|
||||
],
|
||||
origin=MASTER_NODE_ID
|
||||
)
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
@@ -1,34 +1,29 @@
|
||||
import asyncio
|
||||
import os
|
||||
from logging import Logger
|
||||
from typing import Callable, Final
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
|
||||
from shared.types.common import Host, NodeId
|
||||
from shared.types.common import Host
|
||||
from shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from shared.types.models import ModelId
|
||||
from shared.types.worker.common import InstanceId, RunnerId
|
||||
from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
|
||||
from shared.types.worker.runners import FailedRunnerStatus
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.download.shard_downloader import NoopShardDownloader
|
||||
from worker.main import run
|
||||
from worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MASTER_NODE_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
)
|
||||
from worker.worker import Worker
|
||||
|
||||
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
|
||||
|
||||
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
|
||||
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
|
||||
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
|
||||
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
|
||||
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
|
||||
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
|
||||
TASK_1_ID: Final = "55555555-5555-4555-8555-555555555555"
|
||||
TASK_2_ID: Final = "66666666-6666-4666-8666-666666666666"
|
||||
|
||||
@pytest.fixture
|
||||
def user_message() -> str:
|
||||
@@ -63,7 +58,7 @@ async def check_runner_connection(
|
||||
global_events=global_events,
|
||||
)
|
||||
workers.append(worker1)
|
||||
task1 = asyncio.create_task(run(worker1))
|
||||
task1 = asyncio.create_task(run(worker1, logger))
|
||||
tasks.append(task1)
|
||||
|
||||
worker2 = Worker(
|
||||
@@ -74,7 +69,7 @@ async def check_runner_connection(
|
||||
global_events=global_events,
|
||||
)
|
||||
workers.append(worker2)
|
||||
task2 = asyncio.create_task(run(worker2))
|
||||
task2 = asyncio.create_task(run(worker2, logger))
|
||||
tasks.append(task2)
|
||||
|
||||
model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')
|
||||
@@ -152,6 +147,16 @@ async def check_runner_connection(
|
||||
|
||||
# # not now.
|
||||
|
||||
# def test_runner_connection_stress(
|
||||
# logger: Logger,
|
||||
# pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
# hosts: Callable[[int], list[Host]],
|
||||
# chat_completion_task: Callable[[InstanceId, str], Task],
|
||||
# ) -> None:
|
||||
# total_runs = 100
|
||||
# successes = 0
|
||||
# # not now.
|
||||
|
||||
# def test_runner_connection_stress(
|
||||
# logger: Logger,
|
||||
# pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
@@ -161,11 +166,29 @@ async def check_runner_connection(
|
||||
# total_runs = 100
|
||||
# successes = 0
|
||||
|
||||
# for _ in range(total_runs):
|
||||
# # Create a fresh event loop for each iteration
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# for _ in range(total_runs):
|
||||
# # Create a fresh event loop for each iteration
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(check_runner_connection(
|
||||
# logger=logger,
|
||||
# pipeline_shard_meta=pipeline_shard_meta,
|
||||
# hosts=hosts,
|
||||
# chat_completion_task=chat_completion_task,
|
||||
# ))
|
||||
# if result:
|
||||
# successes += 1
|
||||
# finally:
|
||||
# # Cancel all running tasks
|
||||
# pending = asyncio.all_tasks(loop)
|
||||
# for task in pending:
|
||||
# task.cancel()
|
||||
# try:
|
||||
# result = loop.run_until_complete(check_runner_connection(
|
||||
# logger=logger,
|
||||
@@ -181,10 +204,15 @@ async def check_runner_connection(
|
||||
# for task in pending:
|
||||
# task.cancel()
|
||||
|
||||
# # Run the event loop briefly to allow cancellation to complete
|
||||
# loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
# # Run the event loop briefly to allow cancellation to complete
|
||||
# loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
|
||||
# # Close the event loop
|
||||
# loop.close()
|
||||
# # Close the event loop
|
||||
# loop.close()
|
||||
|
||||
# print(f"Runner connection successes: {successes} / {total_runs}")
|
||||
# print(f"Runner connection successes: {successes} / {total_runs}")
|
||||
|
||||
60
worker/tests/test_supervisor/test_memory.py
Normal file
@@ -0,0 +1,60 @@
from asyncio.subprocess import Process
from logging import Logger
from typing import Callable

import psutil
import pytest

from shared.models.model_meta import get_model_meta
from shared.types.common import Host
from shared.types.models import ModelMetadata
from shared.types.tasks import Task, TaskId
from shared.types.worker.common import InstanceId, RunnerError
from shared.types.worker.shards import PipelineShardMetadata
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.tests.constants import INSTANCE_1_ID, TASK_1_ID


def get_memory_mb(process: Process) -> float:
    """
    Returns the resident set size (RSS) memory usage in MiB for the given process.
    """
    ps = psutil.Process(process.pid)
    rss_bytes: int = ps.memory_info().rss  # type: ignore[attr-defined]
    return rss_bytes / (1024 * 1024)


@pytest.fixture
async def model_meta() -> ModelMetadata:
    return await get_model_meta('mlx-community/Llama-3.3-70B-Instruct-4bit')


@pytest.mark.asyncio
async def test_supervisor_inference_exception(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
    logger: Logger,
):
    """Test that the runner holds substantial memory while loaded and that system memory is available again after a failed inference and astop()."""
    model_shard_meta = pipeline_shard_meta(1, 0)

    supervisor = await RunnerSupervisor.create(
        model_shard_meta=model_shard_meta,
        hosts=hosts(1, offset=10),
        logger=logger,
    )

    process: Process = supervisor.runner_process
    memory = get_memory_mb(process)
    assert memory > 30*100  # expect at least ~3 GB resident once the shard is loaded

    task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    task.task_params.messages[0].content = 'EXO RUNNER MUST FAIL'
    with pytest.raises(RunnerError):
        async for _ in supervisor.stream_response(task):
            pass

    await supervisor.astop()

    available_memory_bytes: int = psutil.virtual_memory().available
    print(available_memory_bytes // (2**30))
    assert available_memory_bytes > 30 * 2**30
45
worker/tests/test_supervisor/test_oom.py
Normal file
@@ -0,0 +1,45 @@
from logging import Logger
from typing import Callable

import pytest

from shared.types.common import Host
from shared.types.tasks import (
    Task,
    TaskId,
)
from shared.types.worker.common import InstanceId, RunnerError
from shared.types.worker.shards import PipelineShardMetadata
from worker.runner.runner_supervisor import RunnerSupervisor
from worker.tests.constants import INSTANCE_1_ID, TASK_1_ID


@pytest.fixture
def user_message():
    """Override the default message to ask about France's capital"""
    return "What is the capital of France?"


@pytest.mark.asyncio
async def test_supervisor_single_node_response(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
    logger: Logger,
):
    """Test that an out-of-memory failure in the runner surfaces as a RunnerError from stream_response."""
    model_shard_meta = pipeline_shard_meta(1, 0)

    supervisor = await RunnerSupervisor.create(
        model_shard_meta=model_shard_meta,
        hosts=hosts(1, offset=10),
        logger=logger,
    )

    task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
    task.task_params.messages[0].content = 'EXO RUNNER MUST OOM'
    with pytest.raises(RunnerError):
        async for _ in supervisor.stream_response(task):
            pass

    await supervisor.astop()
93
worker/tests/test_supervisor/test_supervisor_sad.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.types.common import Host
|
||||
from shared.types.tasks import Task, TaskId
|
||||
from shared.types.worker.common import InstanceId, RunnerError
|
||||
from shared.types.worker.shards import PipelineShardMetadata
|
||||
from worker.runner.runner_supervisor import RunnerSupervisor
|
||||
from worker.tests.constants import INSTANCE_1_ID, TASK_1_ID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervisor_instantiation_exception(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
model_shard_meta.immediate_exception = True
|
||||
|
||||
with pytest.raises(RunnerError):
|
||||
await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervisor_instantiation_timeout(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
model_shard_meta.should_timeout = 10 # timeout after 10s
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervisor_inference_exception(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task],
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
|
||||
supervisor = await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
task.task_params.messages[0].content = 'EXO RUNNER MUST FAIL'
|
||||
with pytest.raises(RunnerError):
|
||||
async for _ in supervisor.stream_response(task):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervisor_inference_timeout(
|
||||
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
|
||||
hosts: Callable[..., list[Host]],
|
||||
chat_completion_task: Callable[[InstanceId, TaskId], Task],
|
||||
logger: Logger,
|
||||
):
|
||||
"""Test that asking for the capital of France returns 'Paris' in the response"""
|
||||
model_shard_meta = pipeline_shard_meta(1, 0)
|
||||
|
||||
supervisor = await RunnerSupervisor.create(
|
||||
model_shard_meta=model_shard_meta,
|
||||
hosts=hosts(1, offset=10),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
task.task_params.messages[0].content = 'EXO RUNNER MUST TIMEOUT'
|
||||
with pytest.raises(RunnerError):
|
||||
async for _ in supervisor.stream_response(task):
|
||||
pass
|
||||
@@ -1,4 +1,5 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine

@@ -66,6 +67,11 @@ async def start_polling_node_metrics(

        # Run heavy FLOPs profiling only if enough time has elapsed

        override_memory_env = os.getenv('OVERRIDE_MEMORY')
        override_memory: int | None = (
            int(override_memory_env) * 2**30 if override_memory_env else None
        )

        await callback(
            NodePerformanceProfile(
                model_id=system_info.model_id,
@@ -74,7 +80,7 @@ async def start_polling_node_metrics(
                network_interfaces=network_interfaces,
                memory=MemoryPerformanceProfile(
                    ram_total=total_mem,
                    ram_available=total_mem - used_mem,
                    ram_available=override_memory if override_memory else total_mem - used_mem,
                    swap_total=metrics.memory.swap_total
                    if metrics.memory is not None
                    and metrics.memory.swap_total is not None

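For reference, OVERRIDE_MEMORY is read as a whole number of GiB and converted to bytes before it replaces the reported available RAM. A quick sanity check of that conversion (the environment variable name comes from the diff above; the example value is made up):

# OVERRIDE_MEMORY is interpreted as GiB, e.g. OVERRIDE_MEMORY=64 reports 64 GiB available.
import os

os.environ['OVERRIDE_MEMORY'] = '64'  # made-up example value
override_memory_env = os.getenv('OVERRIDE_MEMORY')
override_memory = int(override_memory_env) * 2**30 if override_memory_env else None
assert override_memory == 64 * 2**30  # 68_719_476_736 bytes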
313
worker/worker.py
@@ -3,7 +3,6 @@ import logging
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from functools import partial
|
||||
from time import process_time
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from shared.db.sqlite import AsyncSQLiteEventStorage
|
||||
@@ -22,7 +21,6 @@ from shared.types.tasks import TaskId, TaskStatus
|
||||
from shared.types.worker.common import RunnerId
|
||||
from shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgressData,
|
||||
@@ -71,104 +69,116 @@ class Worker:
|
||||
|
||||
## Op Executors
|
||||
|
||||
async def _execute_assign_op(
|
||||
self, op: AssignRunnerOp
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
'''
|
||||
A runner has been assigned. We need to also ensure that it's downloaded.
|
||||
This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state.
|
||||
'''
|
||||
self.assigned_runners[op.runner_id] = AssignedRunner(
|
||||
def _create_assigned_runner(self, op: AssignRunnerOp) -> AssignedRunner:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
assigned_runner = AssignedRunner(
|
||||
runner_id=op.runner_id,
|
||||
instance_id=op.instance_id,
|
||||
shard_metadata=op.shard_metadata,
|
||||
hosts=op.hosts,
|
||||
status=DownloadingRunnerStatus(
|
||||
download_progress=DownloadPending(
|
||||
node_id=self.node_id
|
||||
)
|
||||
download_progress=DownloadPending(node_id=self.node_id)
|
||||
),
|
||||
runner=None,
|
||||
)
|
||||
self.assigned_runners[op.runner_id] = assigned_runner
|
||||
return assigned_runner
|
||||
|
||||
assigned_runner = self.assigned_runners[op.runner_id]
|
||||
async def _update_runner_status_to_completed_then_inactive(
|
||||
self, assigned_runner: AssignedRunner
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Updates runner status from downloading to completed, then to inactive."""
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadCompleted(node_id=self.node_id)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
assigned_runner.status = InactiveRunnerStatus()
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
async def _handle_already_downloaded_shard(
|
||||
self, assigned_runner: AssignedRunner
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Handles the case where the shard is already downloaded."""
|
||||
async for event in self._update_runner_status_to_completed_then_inactive(assigned_runner):
|
||||
yield event
|
||||
|
||||
async def _handle_shard_download_process(
|
||||
self, assigned_runner: AssignedRunner, op: AssignRunnerOp, initial_progress: RepoDownloadProgress
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Manages the shard download process with progress tracking."""
|
||||
# Set initial ongoing status
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
downloaded_bytes=initial_progress.downloaded_bytes
|
||||
)
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
# Set up download progress tracking
|
||||
download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue()
|
||||
|
||||
def download_progress_callback(shard: ShardMetadata, progress: RepoDownloadProgress) -> None:
|
||||
download_progress_queue.put_nowait(progress)
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
download_task = asyncio.create_task(self.shard_downloader.ensure_shard(op.shard_metadata))
|
||||
|
||||
try:
|
||||
async for event in self._monitor_download_progress(assigned_runner, download_progress_queue):
|
||||
yield event
|
||||
finally:
|
||||
if not download_task.done():
|
||||
download_task.cancel()
|
||||
|
||||
async def _monitor_download_progress(
|
||||
self, assigned_runner: AssignedRunner, download_progress_queue: asyncio.Queue[RepoDownloadProgress]
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Monitors download progress and yields status updates."""
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
while True:
|
||||
progress: RepoDownloadProgress = await asyncio.wait_for(download_progress_queue.get(), timeout=15)
|
||||
|
||||
if progress.status == "complete":
|
||||
async for event in self._update_runner_status_to_completed_then_inactive(assigned_runner):
|
||||
yield event
|
||||
break
|
||||
elif progress.status == "in_progress":
|
||||
if time.monotonic() - last_progress_time > throttle_interval_secs:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=progress.total_bytes,
|
||||
downloaded_bytes=progress.downloaded_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
last_progress_time = time.monotonic()
|
||||
|
||||
async def _execute_assign_op(
|
||||
self, op: AssignRunnerOp
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""
|
||||
A runner has been assigned. We need to also ensure that it's downloaded.
|
||||
This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state.
|
||||
"""
|
||||
assigned_runner = self._create_assigned_runner(op)
|
||||
initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata)
|
||||
|
||||
if initial_progress.status == "complete":
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadCompleted(
|
||||
node_id=self.node_id
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
assigned_runner.status = InactiveRunnerStatus()
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
return
|
||||
async for event in self._handle_already_downloaded_shard(assigned_runner):
|
||||
yield event
|
||||
else:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
downloaded_bytes=initial_progress.downloaded_bytes
|
||||
)
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
# Download it!
|
||||
# TODO: we probably want download progress as part of a callback that gets passed to the downloader.
|
||||
download_progress_queue: asyncio.Queue[RepoDownloadProgress] = asyncio.Queue()
|
||||
def download_progress_callback(shard: ShardMetadata, progress: RepoDownloadProgress) -> None:
|
||||
download_progress_queue.put_nowait(progress)
|
||||
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
|
||||
asyncio.create_task(self.shard_downloader.ensure_shard(op.shard_metadata))
|
||||
|
||||
# TODO: Dynamic timeout, timeout on no packet update received.
|
||||
timeout_secs = 10 * 60
|
||||
start_time = process_time()
|
||||
last_yield_progress = start_time
|
||||
while process_time() - start_time < timeout_secs:
|
||||
progress: RepoDownloadProgress = await download_progress_queue.get()
|
||||
if progress.status == "complete":
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
assigned_runner.status = InactiveRunnerStatus()
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
break
|
||||
elif progress.status == "in_progress":
|
||||
if process_time() - last_yield_progress > 1:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=progress.total_bytes,
|
||||
downloaded_bytes=progress.downloaded_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
last_yield_progress = process_time()
|
||||
else:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadFailed(
|
||||
node_id=self.node_id,
|
||||
error_message=f"Timeout downloading model: {op.shard_metadata.model_meta.model_id}"
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
async for event in self._handle_shard_download_process(assigned_runner, op, initial_progress):
|
||||
yield event
|
||||
|
||||
async def _execute_unassign_op(
|
||||
self, op: UnassignRunnerOp
|
||||
@@ -193,39 +203,32 @@ class Worker:
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
assigned_runner = self.assigned_runners[op.runner_id]
|
||||
|
||||
# TODO: This should be dynamic, based on the size of the model.
|
||||
if not initialize_timeout:
|
||||
gigabytes_per_second = 10
|
||||
kilobytes_per_second = gigabytes_per_second * 1024 * 1024
|
||||
|
||||
shard = assigned_runner.shard_metadata
|
||||
weights_size_kb = (shard.end_layer - shard.start_layer) / shard.n_layers * shard.model_meta.storage_size_kilobytes
|
||||
|
||||
initialize_timeout = weights_size_kb / kilobytes_per_second + 120.0 # Add a constant 120.0 to ensure connection can be made as well
|
||||
|
||||
self.logger.info(f"initialize_timeout: {initialize_timeout}")
|
||||
|
||||
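# Worked example of the fallback initialize_timeout above (illustrative numbers only,
# not from this commit): a 40 GB shard gives weights_size_kb ~= 40 * 1024 * 1024 KB;
# at the assumed 10 GB/s that is ~4 s of load budget, so
# initialize_timeout ~= 4.0 + 120.0 ~= 124 s, dominated by the constant connection allowance.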
try:
|
||||
assigned_runner.runner = await asyncio.wait_for(
|
||||
RunnerSupervisor.create(
|
||||
model_shard_meta=assigned_runner.shard_metadata,
|
||||
hosts=assigned_runner.hosts,
|
||||
logger=self.logger,
|
||||
),
|
||||
timeout=initialize_timeout,
|
||||
)
|
||||
except TimeoutError as e:
|
||||
import traceback
|
||||
|
||||
tb = traceback.format_exc()
|
||||
e = Exception(f"{type(e).__name__}: {str(e)}. Traceback: {tb}")
|
||||
async for event in self._fail_runner(e=e, runner_id=op.runner_id):
|
||||
yield event
|
||||
return
|
||||
assigned_runner.runner = await RunnerSupervisor.create(
|
||||
model_shard_meta=assigned_runner.shard_metadata,
|
||||
hosts=assigned_runner.hosts,
|
||||
logger=self.logger,
|
||||
initialize_timeout=initialize_timeout
|
||||
)
|
||||
|
||||
if assigned_runner.runner.healthy:
|
||||
assigned_runner.status = LoadedRunnerStatus()
|
||||
else:
|
||||
# Log detailed reasons why the runner is not healthy
|
||||
runner = assigned_runner.runner
|
||||
health_issues: list[str] = []
|
||||
|
||||
if not runner.running:
|
||||
health_issues.append("runner.running is False")
|
||||
if runner.runner_process.returncode is not None:
|
||||
health_issues.append(f"runner_process.returncode is {runner.runner_process.returncode}")
|
||||
if runner.runner_process.stdin is None:
|
||||
health_issues.append("runner_process.stdin is None")
|
||||
elif runner.runner_process.stdin.is_closing():
|
||||
health_issues.append("runner_process.stdin is closing")
|
||||
if runner.runner_process.stdout is None:
|
||||
health_issues.append("runner_process.stdout is None")
|
||||
|
||||
self.logger.warning(f"Runner status is not healthy: {', '.join(health_issues)}")
|
||||
assigned_runner.status = FailedRunnerStatus()
|
||||
yield self.assigned_runners[op.runner_id].status_update_event()
|
||||
|
||||
@@ -251,6 +254,9 @@ class Worker:
|
||||
'''
|
||||
assigned_runner = self.assigned_runners[op.runner_id]
|
||||
|
||||
if isinstance(assigned_runner.runner, RunnerSupervisor):
|
||||
await assigned_runner.runner.astop() # astop the runner to ensure it clears out of memory.
|
||||
|
||||
assigned_runner.status = FailedRunnerStatus()
|
||||
yield self.assigned_runners[op.runner_id].status_update_event()
|
||||
|
||||
@@ -280,37 +286,30 @@ class Worker:
|
||||
task_status=TaskStatus.RUNNING,
|
||||
))
|
||||
|
||||
try:
|
||||
assert assigned_runner.runner is not None
|
||||
assert assigned_runner.runner.healthy
|
||||
|
||||
async for chunk in assigned_runner.runner.stream_response(
|
||||
task=op.task,
|
||||
request_started_callback=partial(running_callback, queue)):
|
||||
if assigned_runner.shard_metadata.device_rank == 0:
|
||||
await queue.put(ChunkGenerated(
|
||||
# todo: at some point we will no longer have a bijection between task_id and row_id.
|
||||
# So we probably want to store a mapping between these two in our Worker object.
|
||||
command_id=chunk.command_id,
|
||||
chunk=chunk
|
||||
))
|
||||
assert assigned_runner.runner is not None
|
||||
assert assigned_runner.runner.healthy
|
||||
|
||||
async for chunk in assigned_runner.runner.stream_response(
|
||||
task=op.task,
|
||||
request_started_callback=partial(running_callback, queue)):
|
||||
if assigned_runner.shard_metadata.device_rank == 0:
|
||||
await queue.put(TaskStateUpdated(
|
||||
task_id=op.task.task_id,
|
||||
task_status=TaskStatus.COMPLETE,
|
||||
await queue.put(ChunkGenerated(
|
||||
# todo: at some point we will no longer have a bijection between task_id and row_id.
|
||||
# So we probably want to store a mapping between these two in our Worker object.
|
||||
command_id=chunk.command_id,
|
||||
chunk=chunk
|
||||
))
|
||||
|
||||
# After a successful inference:
|
||||
assigned_runner.status = LoadedRunnerStatus()
|
||||
await queue.put(assigned_runner.status_update_event())
|
||||
if assigned_runner.shard_metadata.device_rank == 0:
|
||||
await queue.put(TaskStateUpdated(
|
||||
task_id=op.task.task_id,
|
||||
task_status=TaskStatus.COMPLETE,
|
||||
))
|
||||
|
||||
# After a successful inference:
|
||||
assigned_runner.status = LoadedRunnerStatus()
|
||||
await queue.put(assigned_runner.status_update_event())
|
||||
|
||||
except Exception as e:
|
||||
# An exception occurs in the runner supervisor
|
||||
self.logger.warning(f'Runner failed whilst running inference task. Task: {op.task}. Error: {e}')
|
||||
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
|
||||
await queue.put(event)
|
||||
|
||||
queue: Queue[Event] = asyncio.Queue()
|
||||
task = asyncio.create_task(inner_execute(queue))
|
||||
@@ -320,31 +319,31 @@ class Worker:
|
||||
|
||||
try:
|
||||
# Yield items from the queue
|
||||
# timeout = 30.
|
||||
timeout = 3.
|
||||
while True:
|
||||
item: Event = await asyncio.wait_for(queue.get(), timeout=timeout)
|
||||
if task.done() and (exception := task.exception()):
|
||||
raise exception
|
||||
|
||||
try:
|
||||
# Use a timeout to periodically check task status
|
||||
item: Event = await asyncio.wait_for(queue.get(), timeout=0.01)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
yield item
|
||||
timeout = 2.
|
||||
if isinstance(item, RunnerStatusUpdated) and isinstance(
|
||||
item.runner_status, (LoadedRunnerStatus, FailedRunnerStatus)
|
||||
):
|
||||
if isinstance(item.runner_status, LoadedRunnerStatus):
|
||||
assigned_runner.failures = []
|
||||
|
||||
|
||||
break
|
||||
except TimeoutError as e:
|
||||
# Runner supervisor doesn't respond in time; so we put the runner & task into a failed state
|
||||
self.logger.warning(f'Timed out waiting for runner response to inference task. Task: {op.task}.')
|
||||
async for event in self._fail_task(e, op.runner_id, op.task.task_id):
|
||||
yield event
|
||||
finally:
|
||||
# Ensure the task is cleaned up
|
||||
try:
|
||||
await asyncio.wait_for(task, timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Timed out waiting for task cleanup after inference execution.")
|
||||
|
||||
|
||||
|
||||
## Operation Planner
|
||||
|
||||
@@ -368,7 +367,7 @@ class Worker:
|
||||
yield event
|
||||
|
||||
|
||||
async def _fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]:
|
||||
async def fail_runner(self, e: Exception, runner_id: RunnerId) -> AsyncGenerator[Event]:
|
||||
if runner_id in self.assigned_runners:
|
||||
assigned_runner = self.assigned_runners[runner_id]
|
||||
|
||||
@@ -383,15 +382,15 @@ class Worker:
|
||||
|
||||
# Reset failure count back to 0 when successful
|
||||
if len(assigned_runner.failures) >= 3:
|
||||
# Too many retries. We will emit a DeleteInstance
|
||||
# Too many retries. We will emit a DeleteInstance
|
||||
yield InstanceDeleted(
|
||||
instance_id=assigned_runner.instance_id
|
||||
)
|
||||
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
|
||||
async def _fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]:
|
||||
|
||||
async def fail_task(self, e: Exception, runner_id: RunnerId, task_id: TaskId) -> AsyncGenerator[Event]:
|
||||
if runner_id in self.assigned_runners:
|
||||
yield TaskStateUpdated(
|
||||
task_id=task_id,
|
||||
@@ -404,7 +403,7 @@ class Worker:
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
async for event in self._fail_runner(e, runner_id):
|
||||
async for event in self.fail_runner(e, runner_id):
|
||||
yield event
|
||||
|
||||